]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Project] Multi-class classification project baseline
authorVsevolod Stakhov <vsevolod@rspamd.com>
Sun, 20 Jul 2025 16:11:52 +0000 (17:11 +0100)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Sun, 20 Jul 2025 16:11:52 +0000 (17:11 +0100)
18 files changed:
lualib/lua_bayes_redis.lua
lualib/redis_scripts/bayes_cache_learn.lua
lualib/redis_scripts/bayes_classify.lua
lualib/redis_scripts/bayes_learn.lua
src/libserver/cfg_file.h
src/libserver/cfg_rcl.cxx
src/libserver/cfg_utils.cxx
src/libserver/task.c
src/libserver/task.h
src/libstat/backends/cdb_backend.cxx
src/libstat/backends/mmaped_file.c
src/libstat/backends/redis_backend.cxx
src/libstat/backends/sqlite3_backend.c
src/libstat/classifiers/bayes.c
src/libstat/classifiers/classifiers.h
src/libstat/stat_api.h
src/libstat/stat_config.c
src/libstat/stat_process.c

index 782e6fc4729602234bf975d6cb54d987533bf443..59952131ab0d9a0c7869090bb2eab04e8d397b62 100644 (file)
@@ -25,25 +25,56 @@ local ucl = require "ucl"
 local N = "bayes"
 
 local function gen_classify_functor(redis_params, classify_script_id)
-  return function(task, expanded_key, id, is_spam, stat_tokens, callback)
-
+  return function(task, expanded_key, id, class_labels, stat_tokens, callback)
     local function classify_redis_cb(err, data)
       lua_util.debugm(N, task, 'classify redis cb: %s, %s', err, data)
       if err then
         callback(task, false, err)
       else
-        callback(task, true, data[1], data[2], data[3], data[4])
+        -- Handle both binary and multi-class results
+        if type(data[1]) == "table" then
+          -- Multi-class format: [learned_counts_table, outputs_table]
+          -- Convert to binary format for backward compatibility if needed
+          local learned_counts = data[1]
+          local outputs = data[2]
+
+          -- For now, return ham/spam data if available for backward compatibility
+          local learned_ham = learned_counts["H"] or learned_counts["ham"] or 0
+          local learned_spam = learned_counts["S"] or learned_counts["spam"] or 0
+          local output_ham = outputs["H"] or outputs["ham"] or {}
+          local output_spam = outputs["S"] or outputs["spam"] or {}
+
+          callback(task, true, learned_ham, learned_spam, output_ham, output_spam)
+        else
+          -- Binary format: [learned_ham, learned_spam, output_ham, output_spam]
+          callback(task, true, data[1], data[2], data[3], data[4])
+        end
+      end
+    end
+
+    -- Determine class labels to send to Redis script
+    local script_class_labels
+    if type(class_labels) == "table" then
+      script_class_labels = class_labels
+    else
+      -- Single class label or boolean compatibility
+      if class_labels == true or class_labels == "true" then
+        script_class_labels = "S"          -- spam
+      elseif class_labels == false or class_labels == "false" then
+        script_class_labels = "H"          -- ham
+      else
+        script_class_labels = class_labels -- string class label
       end
     end
 
     lua_redis.exec_redis_script(classify_script_id,
-        { task = task, is_write = false, key = expanded_key },
-        classify_redis_cb, { expanded_key, stat_tokens })
+      { task = task, is_write = false, key = expanded_key },
+      classify_redis_cb, { expanded_key, script_class_labels, stat_tokens })
   end
 end
 
 local function gen_learn_functor(redis_params, learn_script_id)
-  return function(task, expanded_key, id, is_spam, symbol, is_unlearn, stat_tokens, callback, maybe_text_tokens)
+  return function(task, expanded_key, id, class_label, symbol, is_unlearn, stat_tokens, callback, maybe_text_tokens)
     local function learn_redis_cb(err, data)
       lua_util.debugm(N, task, 'learn redis cb: %s, %s', err, data)
       if err then
@@ -53,17 +84,24 @@ local function gen_learn_functor(redis_params, learn_script_id)
       end
     end
 
+    -- Convert class_label for backward compatibility
+    local script_class_label = class_label
+    if class_label == true or class_label == "true" then
+      script_class_label = "S" -- spam
+    elseif class_label == false or class_label == "false" then
+      script_class_label = "H" -- ham
+    end
+
     if maybe_text_tokens then
       lua_redis.exec_redis_script(learn_script_id,
-          { task = task, is_write = true, key = expanded_key },
-          learn_redis_cb,
-          { expanded_key, tostring(is_spam), symbol, tostring(is_unlearn), stat_tokens, maybe_text_tokens })
+        { task = task, is_write = true, key = expanded_key },
+        learn_redis_cb,
+        { expanded_key, script_class_label, symbol, tostring(is_unlearn), stat_tokens, maybe_text_tokens })
     else
       lua_redis.exec_redis_script(learn_script_id,
-          { task = task, is_write = true, key = expanded_key },
-          learn_redis_cb, { expanded_key, tostring(is_spam), symbol, tostring(is_unlearn), stat_tokens })
+        { task = task, is_write = true, key = expanded_key },
+        learn_redis_cb, { expanded_key, script_class_label, symbol, tostring(is_unlearn), stat_tokens })
     end
-
   end
 end
 
@@ -112,8 +150,7 @@ end
 --- @param classifier_ucl ucl of the classifier config
 --- @param statfile_ucl ucl of the statfile config
 --- @return a pair of (classify_functor, learn_functor) or `nil` in case of error
-exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol, is_spam, ev_base, stat_periodic_cb)
-
+exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol, class_label, ev_base, stat_periodic_cb)
   local redis_params = load_redis_params(classifier_ucl, statfile_ucl)
 
   if not redis_params then
@@ -137,7 +174,6 @@ exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol,
 
   if ev_base then
     rspamd_config:add_periodic(ev_base, 0.0, function(cfg, _)
-
       local function stat_redis_cb(err, data)
         lua_util.debugm(N, cfg, 'stat redis cb: %s, %s', err, data)
 
@@ -162,12 +198,23 @@ exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol,
         end
       end
 
+      -- Convert class_label to learn key
+      local learn_key
+      if class_label == true or class_label == "true" or class_label == "S" then
+        learn_key = "learns_spam"
+      elseif class_label == false or class_label == "false" or class_label == "H" then
+        learn_key = "learns_ham"
+      else
+        -- For other class labels, use learns_<class_label>
+        learn_key = "learns_" .. string.lower(tostring(class_label))
+      end
+
       lua_redis.exec_redis_script(stat_script_id,
-          { ev_base = ev_base, cfg = cfg, is_write = false },
-          stat_redis_cb, { tostring(cursor),
-                           symbol,
-                           is_spam and "learns_spam" or "learns_ham",
-                           tostring(max_users) })
+        { ev_base = ev_base, cfg = cfg, is_write = false },
+        stat_redis_cb, { tostring(cursor),
+          symbol,
+          learn_key,
+          tostring(max_users) })
       return statfile_ucl.monitor_timeout or classifier_ucl.monitor_timeout or 30.0
     end)
   end
@@ -178,7 +225,6 @@ end
 local function gen_cache_check_functor(redis_params, check_script_id, conf)
   local packed_conf = ucl.to_format(conf, 'msgpack')
   return function(task, cache_id, callback)
-
     local function classify_redis_cb(err, data)
       lua_util.debugm(N, task, 'check cache redis cb: %s, %s (%s)', err, data, type(data))
       if err then
@@ -194,24 +240,33 @@ local function gen_cache_check_functor(redis_params, check_script_id, conf)
 
     lua_util.debugm(N, task, 'checking cache: %s', cache_id)
     lua_redis.exec_redis_script(check_script_id,
-        { task = task, is_write = false, key = cache_id },
-        classify_redis_cb, { cache_id, packed_conf })
+      { task = task, is_write = false, key = cache_id },
+      classify_redis_cb, { cache_id, packed_conf })
   end
 end
 
 local function gen_cache_learn_functor(redis_params, learn_script_id, conf)
   local packed_conf = ucl.to_format(conf, 'msgpack')
-  return function(task, cache_id, is_spam)
+  return function(task, cache_id, class_name)
     local function learn_redis_cb(err, data)
       lua_util.debugm(N, task, 'learn_cache redis cb: %s, %s', err, data)
     end
 
-    lua_util.debugm(N, task, 'try to learn cache: %s', cache_id)
-    lua_redis.exec_redis_script(learn_script_id,
-        { task = task, is_write = true, key = cache_id },
-        learn_redis_cb,
-        { cache_id, is_spam and "1" or "0", packed_conf })
+    -- Handle backward compatibility for boolean values
+    local cache_class_name = class_name
+    if type(class_name) == "boolean" then
+      cache_class_name = class_name and "spam" or "ham"
+    elseif class_name == true or class_name == "true" then
+      cache_class_name = "spam"
+    elseif class_name == false or class_name == "false" then
+      cache_class_name = "ham"
+    end
 
+    lua_util.debugm(N, task, 'try to learn cache: %s as %s', cache_id, cache_class_name)
+    lua_redis.exec_redis_script(learn_script_id,
+      { task = task, is_write = true, key = cache_id },
+      learn_redis_cb,
+      { cache_id, cache_class_name, packed_conf })
   end
 end
 
@@ -225,8 +280,8 @@ exports.lua_bayes_init_cache = function(classifier_ucl, statfile_ucl)
   local default_conf = {
     cache_prefix = "learned_ids",
     cache_max_elt = 10000, -- Maximum number of elements in the cache key
-    cache_max_keys = 5, -- Maximum number of keys in the cache
-    cache_elt_len = 32, -- Length of the element in the cache (will trim id to that value)
+    cache_max_keys = 5,    -- Maximum number of keys in the cache
+    cache_elt_len = 32,    -- Length of the element in the cache (will trim id to that value)
   }
 
   local conf = lua_util.override_defaults(default_conf, classifier_ucl)
@@ -241,7 +296,7 @@ exports.lua_bayes_init_cache = function(classifier_ucl, statfile_ucl)
   local learn_script_id = lua_redis.load_redis_script_from_file("bayes_cache_learn.lua", redis_params)
 
   return gen_cache_check_functor(redis_params, check_script_id, conf), gen_cache_learn_functor(redis_params,
-      learn_script_id, conf)
+    learn_script_id, conf)
 end
 
 return exports
index 7d44a73efc7ed7854ec7c6fd527b988fc913ab25..b3e15a9bc3fdfd4ffdc92f1307ddf8930f42788f 100644 (file)
@@ -1,12 +1,19 @@
--- Lua script to perform cache checking for bayes classification
+-- Lua script to perform cache checking for bayes classification (multi-class)
 -- This script accepts the following parameters:
 -- key1 - cache id
--- key2 - is spam (1 or 0)
+-- key2 - class name (e.g. "spam", "ham", "transactional")
 -- key3 - configuration table in message pack
 
 local cache_id = KEYS[1]
-local is_spam = KEYS[2]
+local class_name = KEYS[2]
 local conf = cmsgpack.unpack(KEYS[3])
+
+-- Handle backward compatibility for binary values
+if class_name == "1" then
+  class_name = "spam"
+elseif class_name == "0" then
+  class_name = "ham"
+end
 cache_id = string.sub(cache_id, 1, conf.cache_elt_len)
 
 -- Try each prefix that is in Redis (as some other instance might have set it)
@@ -15,8 +22,8 @@ for i = 0, conf.cache_max_keys do
   local have = redis.call('HGET', prefix, cache_id)
 
   if have then
-    -- Already in cache, but is_spam changes when relearning
-    redis.call('HSET', prefix, cache_id, is_spam)
+    -- Already in cache, but class_name changes when relearning
+    redis.call('HSET', prefix, cache_id, class_name)
     return false
   end
 end
@@ -30,7 +37,7 @@ for i = 0, conf.cache_max_keys do
 
     if count < lim then
       -- We can add it to this prefix
-      redis.call('HSET', prefix, cache_id, is_spam)
+      redis.call('HSET', prefix, cache_id, class_name)
       added = true
     end
   end
@@ -46,7 +53,7 @@ if not added then
     if exists then
       if not expired then
         redis.call('DEL', prefix)
-        redis.call('HSET', prefix, cache_id, is_spam)
+        redis.call('HSET', prefix, cache_id, class_name)
 
         -- Do not expire anything else
         expired = true
index e94f645fdf8e7a934af1e95b1e8e46e6d3fd00e3..8e6feb32f80be340c015aa0d67f744a6709dbcf3 100644 (file)
@@ -1,37 +1,90 @@
--- Lua script to perform bayes classification
+-- Lua script to perform bayes classification (multi-class)
 -- This script accepts the following parameters:
 -- key1 - prefix for bayes tokens (e.g. for per-user classification)
--- key2 - set of tokens encoded in messagepack array of strings
+-- key2 - class labels: either table of all class labels (multi-class) or single string (binary)
+-- key3 - set of tokens encoded in messagepack array of strings
 
 local prefix = KEYS[1]
-local output_spam = {}
-local output_ham = {}
+local class_labels_arg = KEYS[2]
+local input_tokens = cmsgpack.unpack(KEYS[3])
 
-local learned_ham = tonumber(redis.call('HGET', prefix, 'learns_ham')) or 0
-local learned_spam = tonumber(redis.call('HGET', prefix, 'learns_spam')) or 0
+-- Determine if this is multi-class (table) or binary (string)
+local class_labels = {}
+if type(class_labels_arg) == "table" then
+  class_labels = class_labels_arg
+else
+  -- Binary compatibility: handle old boolean or single string format
+  if class_labels_arg == "true" then
+    class_labels = { "S" }            -- spam
+  elseif class_labels_arg == "false" then
+    class_labels = { "H" }            -- ham
+  else
+    class_labels = { class_labels_arg } -- single class label
+  end
+end
 
--- Output is a set of pairs (token_index, token_count), tokens that are not
--- found are not filled.
--- This optimisation will save a lot of space for sparse tokens, and in Bayes that assumption is normally held
+-- Get learned counts for all classes
+local learned_counts = {}
+for _, label in ipairs(class_labels) do
+  local key = 'learns_' .. string.lower(label)
+  -- Also try legacy keys for backward compatibility
+  if label == 'H' then
+    key = 'learns_ham'
+  elseif label == 'S' then
+    key = 'learns_spam'
+  end
+  learned_counts[label] = tonumber(redis.call('HGET', prefix, key)) or 0
+end
 
-if learned_ham > 0 and learned_spam > 0 then
-  local input_tokens = cmsgpack.unpack(KEYS[2])
-  for i, token in ipairs(input_tokens) do
-    local token_data = redis.call('HMGET', token, 'H', 'S')
+-- Get token data for all classes (only if we have learns for any class)
+local outputs = {}
+local has_learns = false
+for _, count in pairs(learned_counts) do
+  if count > 0 then
+    has_learns = true
+    break
+  end
+end
 
-    if token_data then
-      local ham_count = token_data[1]
-      local spam_count = token_data[2]
+if has_learns then
+  -- Initialize outputs for each class
+  for _, label in ipairs(class_labels) do
+    outputs[label] = {}
+  end
 
-      if ham_count then
-        table.insert(output_ham, { i, tonumber(ham_count) })
-      end
+  -- Process each token
+  for i, token in ipairs(input_tokens) do
+    local token_data = redis.call('HMGET', token, unpack(class_labels))
 
-      if spam_count then
-        table.insert(output_spam, { i, tonumber(spam_count) })
+    if token_data then
+      for j, label in ipairs(class_labels) do
+        local count = token_data[j]
+        if count then
+          table.insert(outputs[label], { i, tonumber(count) })
+        end
       end
     end
   end
 end
 
-return { learned_ham, learned_spam, output_ham, output_spam }
\ No newline at end of file
+-- Format output for backward compatibility
+if #class_labels == 2 and class_labels[1] == 'H' and class_labels[2] == 'S' then
+  -- Binary format: [learned_ham, learned_spam, output_ham, output_spam]
+  return {
+    learned_counts['H'] or 0,
+    learned_counts['S'] or 0,
+    outputs['H'] or {},
+    outputs['S'] or {}
+  }
+elseif #class_labels == 2 and class_labels[1] == 'S' and class_labels[2] == 'H' then
+  -- Binary format: [learned_ham, learned_spam, output_ham, output_spam]
+  return {
+    learned_counts['H'] or 0,
+    learned_counts['S'] or 0,
+    outputs['H'] or {},
+    outputs['S'] or {}
+  }
+else
+  -- Multi-class format: [learned_counts_table, outputs_table]
+  return { learned_counts, outputs }
+end
index 5456165b6949e55ed18e2dcb63d21f93880afa37..b284a2812326e5e8b6ed26bb27c8f7927e6d0a7a 100644 (file)
@@ -1,14 +1,14 @@
--- Lua script to perform bayes learning
+-- Lua script to perform bayes learning (multi-class)
 -- This script accepts the following parameters:
 -- key1 - prefix for bayes tokens (e.g. for per-user classification)
--- key2 - boolean is_spam
+-- key2 - class label string (e.g. "S", "H", "T")
 -- key3 - string symbol
 -- key4 - boolean is_unlearn
 -- key5 - set of tokens encoded in messagepack array of strings
 -- key6 - set of text tokens (if any) encoded in messagepack array of strings (size must be twice of `KEYS[5]`)
 
 local prefix = KEYS[1]
-local is_spam = KEYS[2] == 'true' and true or false
+local class_label = KEYS[2]
 local symbol = KEYS[3]
 local is_unlearn = KEYS[4] == 'true' and true or false
 local input_tokens = cmsgpack.unpack(KEYS[5])
@@ -18,11 +18,25 @@ if KEYS[6] then
   text_tokens = cmsgpack.unpack(KEYS[6])
 end
 
-local hash_key = is_spam and 'S' or 'H'
-local learned_key = is_spam and 'learns_spam' or 'learns_ham'
+-- Handle backward compatibility for boolean values
+if class_label == 'true' then
+  class_label = 'S' -- spam
+elseif class_label == 'false' then
+  class_label = 'H' -- ham
+end
+
+local hash_key = class_label
+local learned_key = 'learns_' .. string.lower(class_label)
+
+-- Handle legacy keys for backward compatibility
+if class_label == 'S' then
+  learned_key = 'learns_spam'
+elseif class_label == 'H' then
+  learned_key = 'learns_ham'
+end
 
 redis.call('SADD', symbol .. '_keys', prefix)
-redis.call('HSET', prefix, 'version', '2') -- new schema
+redis.call('HSET', prefix, 'version', '2')                         -- new schema
 redis.call('HINCRBY', prefix, learned_key, is_unlearn and -1 or 1) -- increase or decrease learned count
 
 for i, token in ipairs(input_tokens) do
index 36941da7ad97d52cc2a26f8e4102276fd95d413e..cd2ab43141061c037ddec08a91cb7a6f4f2f5b59 100644 (file)
@@ -139,7 +139,8 @@ struct rspamd_statfile_config {
        char *symbol;                          /**< symbol of statfile                                                                  */
        char *label;                           /**< label of this statfile                                                              */
        ucl_object_t *opts;                    /**< other options                                                                               */
-       gboolean is_spam;                      /**< spam flag                                                                                   */
+       char *class_name;                      /**< class name for multi-class classification                   */
+       gboolean is_spam;                      /**< DEPRECATED: spam flag - use class_name instead              */
        struct rspamd_classifier_config *clcf; /**< parent pointer of classifier configuration                  */
        gpointer data;                         /**< opaque data                                                                                 */
 };
@@ -182,6 +183,8 @@ struct rspamd_classifier_config {
        double min_prob_strength;                  /**< use only tokens with probability in [0.5 - MPS, 0.5 + MPS] */
        unsigned int min_learns;                   /**< minimum number of learns for each statfile                      */
        unsigned int flags;
+       GHashTable *class_labels; /**< class_name -> backend_symbol mapping for multi-class */
+       GPtrArray *class_names;   /**< ordered list of class names                                              */
 };
 
 struct rspamd_worker_bind_conf {
@@ -621,12 +624,25 @@ void rspamd_config_insert_classify_symbols(struct rspamd_config *cfg);
  */
 gboolean rspamd_config_check_statfiles(struct rspamd_classifier_config *cf);
 
-/*
- * Find classifier config by name
+/**
+ * Multi-class configuration helpers
+ */
+gboolean rspamd_config_parse_class_labels(ucl_object_t *obj,
+                                                                                 GHashTable **class_labels);
+
+gboolean rspamd_config_migrate_binary_config(struct rspamd_statfile_config *stcf);
+
+gboolean rspamd_config_validate_class_config(struct rspamd_classifier_config *ccf,
+                                                                                        GError **err);
+
+const char *rspamd_config_get_class_label(struct rspamd_classifier_config *ccf,
+                                                                                 const char *class_name);
+
+/**
+ * Find classifier by name
  */
 struct rspamd_classifier_config *rspamd_config_find_classifier(
-       struct rspamd_config *cfg,
-       const char *name);
+       struct rspamd_config *cfg, const char *name);
 
 void rspamd_ucl_add_conf_macros(struct ucl_parser *parser,
                                                                struct rspamd_config *cfg);
index 0a48e8a4f435b43602498a6e3a4ea8cdf8e36abb..3f0a9606a24ef30a2b1a9b2223e9418f40388ae2 100644 (file)
@@ -1197,37 +1197,84 @@ rspamd_rcl_statfile_handler(rspamd_mempool_t *pool, const ucl_object_t *obj,
                st->opts = (ucl_object_t *) obj;
                st->clcf = ccf;
 
-               const auto *val = ucl_object_lookup(obj, "spam");
-               if (val == nullptr) {
+               /* Handle migration from old 'spam' field to new 'class' field */
+               const auto *class_val = ucl_object_lookup(obj, "class");
+               const auto *spam_val = ucl_object_lookup(obj, "spam");
+
+               if (class_val != nullptr && spam_val != nullptr) {
+                       msg_warn_config("statfile %s has both 'class' and 'spam' fields, using 'class' field",
+                                                       st->symbol);
+               }
+
+               if (class_val == nullptr && spam_val == nullptr) {
+                       /* Neither field present, try to guess by symbol name */
                        msg_info_config(
-                               "statfile %s has no explicit 'spam' setting, trying to guess by symbol",
+                               "statfile %s has no explicit 'class' or 'spam' setting, trying to guess by symbol",
                                st->symbol);
                        if (rspamd_substring_search_caseless(st->symbol,
                                                                                                 strlen(st->symbol), "spam", 4) != -1) {
                                st->is_spam = TRUE;
+                               st->class_name = rspamd_mempool_strdup(pool, "spam");
                        }
                        else if (rspamd_substring_search_caseless(st->symbol,
                                                                                                          strlen(st->symbol), "ham", 3) != -1) {
                                st->is_spam = FALSE;
+                               st->class_name = rspamd_mempool_strdup(pool, "ham");
                        }
                        else {
                                g_set_error(err,
                                                        CFG_RCL_ERROR,
                                                        EINVAL,
-                                                       "cannot guess spam setting from %s",
+                                                       "cannot guess class setting from %s, please specify 'class' field",
                                                        st->symbol);
                                return FALSE;
                        }
-                       msg_info_config("guessed that statfile with symbol %s is %s",
-                                                       st->symbol,
-                                                       st->is_spam ? "spam" : "ham");
+                       msg_info_config("guessed that statfile with symbol %s has class '%s'",
+                                                       st->symbol, st->class_name);
                }
+               else if (class_val == nullptr && spam_val != nullptr) {
+                       /* Only spam field present - migrate to class */
+                       msg_warn_config("statfile %s uses deprecated 'spam' field, please use 'class' instead",
+                                                       st->symbol);
+                       if (st->is_spam) {
+                               st->class_name = rspamd_mempool_strdup(pool, "spam");
+                       }
+                       else {
+                               st->class_name = rspamd_mempool_strdup(pool, "ham");
+                       }
+               }
+               /* If class field is present, it was already parsed by the default parser */
                return TRUE;
        }
 
        return FALSE;
 }
 
+static gboolean
+rspamd_rcl_class_labels_handler(rspamd_mempool_t *pool,
+                                                               const ucl_object_t *obj,
+                                                               const char *key,
+                                                               gpointer ud,
+                                                               struct rspamd_rcl_section *section,
+                                                               GError **err)
+{
+       auto *ccf = static_cast<rspamd_classifier_config *>(ud);
+
+       if (obj->type != UCL_OBJECT) {
+               g_set_error(err, CFG_RCL_ERROR, EINVAL,
+                                       "class_labels must be an object");
+               return FALSE;
+       }
+
+       if (!rspamd_config_parse_class_labels((ucl_object_t *) obj, &ccf->class_labels)) {
+               g_set_error(err, CFG_RCL_ERROR, EINVAL,
+                                       "invalid class_labels configuration");
+               return FALSE;
+       }
+
+       return TRUE;
+}
+
 static gboolean
 rspamd_rcl_classifier_handler(rspamd_mempool_t *pool,
                                                          const ucl_object_t *obj,
@@ -1375,6 +1422,21 @@ rspamd_rcl_classifier_handler(rspamd_mempool_t *pool,
        }
 
        ccf->opts = (ucl_object_t *) obj;
+
+       /* Validate multi-class configuration */
+       GError *validation_err = nullptr;
+       if (!rspamd_config_validate_class_config(ccf, &validation_err)) {
+               if (validation_err) {
+                       g_propagate_error(err, validation_err);
+               }
+               else {
+                       g_set_error(err, CFG_RCL_ERROR, EINVAL,
+                                               "multi-class configuration validation failed for classifier '%s'",
+                                               ccf->name ? ccf->name : "unknown");
+               }
+               return FALSE;
+       }
+
        cfg->classifiers = g_list_prepend(cfg->classifiers, ccf);
 
        return TRUE;
@@ -2504,6 +2566,18 @@ rspamd_rcl_config_init(struct rspamd_config *cfg, GHashTable *skip_sections)
                                                                           0,
                                                                           "Name of classifier");
 
+               /*
+                * Multi-class configuration
+                */
+               rspamd_rcl_add_section_doc(&top, sub,
+                                                                  "class_labels", nullptr,
+                                                                  rspamd_rcl_class_labels_handler,
+                                                                  UCL_OBJECT,
+                                                                  FALSE,
+                                                                  TRUE,
+                                                                  sub->doc_ref,
+                                                                  "Class to backend label mapping for multi-class classification");
+
                /*
                 * Statfile defaults
                 */
@@ -2521,12 +2595,18 @@ rspamd_rcl_config_init(struct rspamd_config *cfg, GHashTable *skip_sections)
                                                                           G_STRUCT_OFFSET(struct rspamd_statfile_config, label),
                                                                           0,
                                                                           "Statfile unique label");
+               rspamd_rcl_add_default_handler(ssub,
+                                                                          "class",
+                                                                          rspamd_rcl_parse_struct_string,
+                                                                          G_STRUCT_OFFSET(struct rspamd_statfile_config, class_name),
+                                                                          0,
+                                                                          "Class name for multi-class classification");
                rspamd_rcl_add_default_handler(ssub,
                                                                           "spam",
                                                                           rspamd_rcl_parse_struct_boolean,
                                                                           G_STRUCT_OFFSET(struct rspamd_statfile_config, is_spam),
                                                                           0,
-                                                                          "Sets if this statfile contains spam samples");
+                                                                          "DEPRECATED: Sets if this statfile contains spam samples (use 'class' instead)");
        }
 
        if (!(skip_sections && g_hash_table_lookup(skip_sections, "composite"))) {
index c7bb202108dd01095d9fb309621128aef65fb357..c8c08343970a689efae191db3d3776ccded50fae 100644 (file)
@@ -3042,3 +3042,169 @@ rspamd_ip_is_local_cfg(struct rspamd_config *cfg,
 
        return FALSE;
 }
+
+gboolean
+rspamd_config_parse_class_labels(ucl_object_t *obj, GHashTable **class_labels)
+{
+       const ucl_object_t *cur;
+       ucl_object_iter_t it = nullptr;
+       const char *class_name, *label;
+
+       if (!obj || ucl_object_type(obj) != UCL_OBJECT) {
+               return FALSE;
+       }
+
+       *class_labels = g_hash_table_new_full(g_str_hash, g_str_equal, g_free, g_free);
+
+       while ((cur = ucl_object_iterate(obj, &it, true)) != nullptr) {
+               class_name = ucl_object_key(cur);
+               label = ucl_object_tostring(cur);
+
+               if (class_name && label) {
+                       /* Validate class name: alphanumeric + underscore, max 32 chars */
+                       if (strlen(class_name) > 32) {
+                               msg_err("class name '%s' is too long (max 32 characters)", class_name);
+                               g_hash_table_destroy(*class_labels);
+                               *class_labels = nullptr;
+                               return FALSE;
+                       }
+
+                       for (const char *p = class_name; *p; p++) {
+                               if (!g_ascii_isalnum(*p) && *p != '_') {
+                                       msg_err("class name '%s' contains invalid character '%c'", class_name, *p);
+                                       g_hash_table_destroy(*class_labels);
+                                       *class_labels = nullptr;
+                                       return FALSE;
+                               }
+                       }
+
+                       /* Validate label uniqueness */
+                       GHashTableIter label_iter;
+                       gpointer key, value;
+                       g_hash_table_iter_init(&label_iter, *class_labels);
+                       while (g_hash_table_iter_next(&label_iter, &key, &value)) {
+                               if (strcmp((const char *) value, label) == 0) {
+                                       msg_err("backend label '%s' is used by multiple classes", label);
+                                       g_hash_table_destroy(*class_labels);
+                                       *class_labels = nullptr;
+                                       return FALSE;
+                               }
+                       }
+
+                       g_hash_table_insert(*class_labels, g_strdup(class_name), g_strdup(label));
+               }
+       }
+
+       return g_hash_table_size(*class_labels) > 0;
+}
+
+gboolean
+rspamd_config_migrate_binary_config(struct rspamd_statfile_config *stcf)
+{
+       if (stcf->class_name != nullptr) {
+               /* Already migrated or using new format */
+               return TRUE;
+       }
+
+       if (stcf->is_spam) {
+               stcf->class_name = g_strdup("spam");
+               msg_info("migrated statfile '%s' from is_spam=true to class='spam'",
+                                stcf->symbol ? stcf->symbol : "unknown");
+       }
+       else {
+               stcf->class_name = g_strdup("ham");
+               msg_info("migrated statfile '%s' from is_spam=false to class='ham'",
+                                stcf->symbol ? stcf->symbol : "unknown");
+       }
+
+       return TRUE;
+}
+
+gboolean
+rspamd_config_validate_class_config(struct rspamd_classifier_config *ccf, GError **err)
+{
+       GList *cur;
+       GHashTable *seen_classes = nullptr;
+       struct rspamd_statfile_config *stcf;
+       unsigned int class_count = 0;
+
+       if (!ccf || !ccf->statfiles) {
+               g_set_error(err, g_quark_from_static_string("config"), 1,
+                                       "classifier has no statfiles defined");
+               return FALSE;
+       }
+
+       seen_classes = g_hash_table_new_full(g_str_hash, g_str_equal, g_free, nullptr);
+
+       /* Iterate through statfiles and collect classes */
+       cur = ccf->statfiles;
+       while (cur) {
+               stcf = (struct rspamd_statfile_config *) cur->data;
+
+               /* Migrate binary config if needed */
+               if (!rspamd_config_migrate_binary_config(stcf)) {
+                       g_set_error(err, g_quark_from_static_string("config"), 1,
+                                               "failed to migrate binary config for statfile '%s'",
+                                               stcf->symbol ? stcf->symbol : "unknown");
+                       g_hash_table_destroy(seen_classes);
+                       return FALSE;
+               }
+
+               /* Check class name */
+               if (!stcf->class_name || strlen(stcf->class_name) == 0) {
+                       g_set_error(err, g_quark_from_static_string("config"), 1,
+                                               "statfile '%s' has no class defined",
+                                               stcf->symbol ? stcf->symbol : "unknown");
+                       g_hash_table_destroy(seen_classes);
+                       return FALSE;
+               }
+
+               /* Track unique classes */
+               if (!g_hash_table_contains(seen_classes, stcf->class_name)) {
+                       g_hash_table_insert(seen_classes, g_strdup(stcf->class_name), GINT_TO_POINTER(1));
+                       class_count++;
+               }
+
+               cur = g_list_next(cur);
+       }
+
+       /* Validate class count */
+       if (class_count < 2) {
+               g_set_error(err, g_quark_from_static_string("config"), 1,
+                                       "classifier must have at least 2 classes, found %u", class_count);
+               g_hash_table_destroy(seen_classes);
+               return FALSE;
+       }
+
+       if (class_count > 20) {
+               msg_warn("classifier has %u classes, performance may be degraded above 20 classes",
+                                class_count);
+       }
+
+       /* Initialize classifier class tracking */
+       if (ccf->class_names) {
+               g_ptr_array_unref(ccf->class_names);
+       }
+       ccf->class_names = g_ptr_array_new_with_free_func(g_free);
+
+       /* Populate class names array */
+       GHashTableIter iter;
+       gpointer key, value;
+       g_hash_table_iter_init(&iter, seen_classes);
+       while (g_hash_table_iter_next(&iter, &key, &value)) {
+               g_ptr_array_add(ccf->class_names, g_strdup((const char *) key));
+       }
+
+       g_hash_table_destroy(seen_classes);
+       return TRUE;
+}
+
+const char *
+rspamd_config_get_class_label(struct rspamd_classifier_config *ccf, const char *class_name)
+{
+       if (!ccf || !ccf->class_labels || !class_name) {
+               return nullptr;
+       }
+
+       return (const char *) g_hash_table_lookup(ccf->class_labels, class_name);
+}
index 9f5b1f00a1b08b31e606cffcb3c27e2e1b903c19..e0435828461a09f58da70a56f7f735ffc3558b61 100644 (file)
@@ -730,7 +730,7 @@ rspamd_task_process(struct rspamd_task *task, unsigned int stages)
 
                if (all_done && (task->flags & RSPAMD_TASK_FLAG_LEARN_AUTO) &&
                        !RSPAMD_TASK_IS_EMPTY(task) &&
-                       !(task->flags & (RSPAMD_TASK_FLAG_LEARN_SPAM | RSPAMD_TASK_FLAG_LEARN_HAM))) {
+                       !(task->flags & (RSPAMD_TASK_FLAG_LEARN_SPAM | RSPAMD_TASK_FLAG_LEARN_HAM | RSPAMD_TASK_FLAG_LEARN_CLASS))) {
                        rspamd_stat_check_autolearn(task);
                }
                break;
@@ -738,12 +738,32 @@ rspamd_task_process(struct rspamd_task *task, unsigned int stages)
        case RSPAMD_TASK_STAGE_LEARN:
        case RSPAMD_TASK_STAGE_LEARN_PRE:
        case RSPAMD_TASK_STAGE_LEARN_POST:
-               if (task->flags & (RSPAMD_TASK_FLAG_LEARN_SPAM | RSPAMD_TASK_FLAG_LEARN_HAM)) {
+               if (task->flags & (RSPAMD_TASK_FLAG_LEARN_SPAM | RSPAMD_TASK_FLAG_LEARN_HAM | RSPAMD_TASK_FLAG_LEARN_CLASS)) {
                        if (task->err == NULL) {
-                               if (!rspamd_stat_learn(task,
-                                                                          task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM,
-                                                                          task->cfg->lua_state, task->classifier,
-                                                                          st, &stat_error)) {
+                               gboolean learn_result = FALSE;
+
+                               if (task->flags & RSPAMD_TASK_FLAG_LEARN_CLASS) {
+                                       /* Multi-class learning */
+                                       const char *autolearn_class = rspamd_task_get_autolearn_class(task);
+                                       if (autolearn_class) {
+                                               learn_result = rspamd_stat_learn_class(task, autolearn_class,
+                                                                                                                          task->cfg->lua_state, task->classifier,
+                                                                                                                          st, &stat_error);
+                                       }
+                                       else {
+                                               g_set_error(&stat_error, g_quark_from_static_string("stat"), 500,
+                                                                       "No autolearn class specified for multi-class learning");
+                                       }
+                               }
+                               else {
+                                       /* Legacy binary learning */
+                                       learn_result = rspamd_stat_learn(task,
+                                                                                                        task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM,
+                                                                                                        task->cfg->lua_state, task->classifier,
+                                                                                                        st, &stat_error);
+                               }
+
+                               if (!learn_result) {
 
                                        if (stat_error == NULL) {
                                                g_set_error(&stat_error,
index 1c1778fee4d740dd511ca1bbed9723efc57d6f0b..a1742e16084e750b3771914d2358a03de5b190a0 100644 (file)
@@ -104,9 +104,9 @@ enum rspamd_task_stage {
 #define RSPAMD_TASK_FLAG_LEARN_SPAM (1u << 12u)
 #define RSPAMD_TASK_FLAG_LEARN_HAM (1u << 13u)
 #define RSPAMD_TASK_FLAG_LEARN_AUTO (1u << 14u)
+#define RSPAMD_TASK_FLAG_LEARN_CLASS (1u << 25u)
 #define RSPAMD_TASK_FLAG_BROKEN_HEADERS (1u << 15u)
-#define RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS (1u << 16u)
-#define RSPAMD_TASK_FLAG_HAS_HAM_TOKENS (1u << 17u)
+/* Removed RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS and RSPAMD_TASK_FLAG_HAS_HAM_TOKENS - not needed in multi-class */
 #define RSPAMD_TASK_FLAG_EMPTY (1u << 18u)
 #define RSPAMD_TASK_FLAG_PROFILE (1u << 19u)
 #define RSPAMD_TASK_FLAG_GREYLISTED (1u << 20u)
@@ -114,7 +114,7 @@ enum rspamd_task_stage {
 #define RSPAMD_TASK_FLAG_SSL (1u << 22u)
 #define RSPAMD_TASK_FLAG_BAD_UNICODE (1u << 23u)
 #define RSPAMD_TASK_FLAG_MESSAGE_REWRITE (1u << 24u)
-#define RSPAMD_TASK_FLAG_MAX_SHIFT (24u)
+#define RSPAMD_TASK_FLAG_MAX_SHIFT (25u)
 
 /* Request has been done by a local client */
 #define RSPAMD_TASK_PROTOCOL_FLAG_LOCAL_CLIENT (1u << 1u)
index 0f55a725c422134aae9fc5bf81144d194e32bc2e..f6ca9c12d80ae04b5172398bf6c426b04e3e6112 100644 (file)
@@ -393,7 +393,6 @@ rspamd_cdb_process_tokens(struct rspamd_task *task,
                                                  gpointer runtime)
 {
        auto *cdbp = CDB_FROM_RAW(runtime);
-       bool seen_values = false;
 
        for (auto i = 0u; i < tokens->len; i++) {
                rspamd_token_t *tok;
@@ -403,21 +402,13 @@ rspamd_cdb_process_tokens(struct rspamd_task *task,
 
                if (res) {
                        tok->values[id] = res.value();
-                       seen_values = true;
                }
                else {
                        tok->values[id] = 0;
                }
        }
 
-       if (seen_values) {
-               if (cdbp->is_spam()) {
-                       task->flags |= RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS;
-               }
-               else {
-                       task->flags |= RSPAMD_TASK_FLAG_HAS_HAM_TOKENS;
-               }
-       }
+       /* No longer need to set flags - multi-class handles missing data naturally */
 
        return true;
 }
@@ -488,4 +479,4 @@ void rspamd_cdb_close(gpointer ctx)
 {
        auto *cdbp = CDB_FROM_RAW(ctx);
        delete cdbp;
-}
\ No newline at end of file
+}
index 4430bb9a439392c737551ac6b503614e3df94fd1..a6423a1e6cdb689158d10f98d739ae1892b4e806 100644 (file)
@@ -85,8 +85,7 @@ typedef struct {
 
 #define RSPAMD_STATFILE_VERSION \
        {                           \
-               '1', '2'                \
-       }
+               '1', '2'}
 #define BACKUP_SUFFIX ".old"
 
 static void rspamd_mmaped_file_set_block_common(rspamd_mempool_t *pool,
@@ -958,12 +957,7 @@ rspamd_mmaped_file_process_tokens(struct rspamd_task *task, GPtrArray *tokens,
                tok->values[id] = rspamd_mmaped_file_get_block(mf, h1, h2);
        }
 
-       if (mf->cf->is_spam) {
-               task->flags |= RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS;
-       }
-       else {
-               task->flags |= RSPAMD_TASK_FLAG_HAS_HAM_TOKENS;
-       }
+       /* No longer need to set flags - multi-class handles missing data naturally */
 
        return TRUE;
 }
index 7137904e992de536b276373e0c9eb715625b83a0..01ed818c47376bc4aac2127a626f33732094ab59 100644 (file)
@@ -121,9 +121,9 @@ public:
        }
 
        static auto maybe_recover_from_mempool(struct rspamd_task *task, const char *redis_object_expanded,
-                                                                                  bool is_spam) -> std::optional<redis_stat_runtime<T> *>
+                                                                                  const char *class_label) -> std::optional<redis_stat_runtime<T> *>
        {
-               auto var_name = fmt::format("{}_{}", redis_object_expanded, is_spam ? "S" : "H");
+               auto var_name = fmt::format("{}_{}", redis_object_expanded, class_label);
                auto *res = rspamd_mempool_get_variable(task->task_pool, var_name.c_str());
 
                if (res) {
@@ -158,9 +158,9 @@ public:
                return true;
        }
 
-       auto save_in_mempool(bool is_spam) const
+       auto save_in_mempool(const char *class_label) const
        {
-               auto var_name = fmt::format("{}_{}", redis_object_expanded, is_spam ? "S" : "H");
+               auto var_name = fmt::format("{}_{}", redis_object_expanded, class_label);
                /* We do not set destructor for the variable, as it should be already added on creation */
                rspamd_mempool_set_variable(task->task_pool, var_name.c_str(), (gpointer) this, nullptr);
                msg_debug_bayes("saved runtime in mempool at %s", var_name.c_str());
@@ -177,6 +177,26 @@ rspamd_redis_stat_quark(void)
        return g_quark_from_static_string(M);
 }
 
+/*
+ * Get the class label for a statfile (for multi-class support)
+ */
+static const char *
+get_class_label(struct rspamd_statfile_config *stcf)
+{
+       /* Try to get the label from the classifier config first */
+       if (stcf->clcf && stcf->clcf->class_labels && stcf->class_name) {
+               const char *label = rspamd_config_get_class_label(stcf->clcf, stcf->class_name);
+               if (label) {
+                       return label;
+               }
+               /* If no label mapping found, use class name directly */
+               return stcf->class_name;
+       }
+
+       /* Fallback to legacy binary classification */
+       return stcf->is_spam ? "S" : "H";
+}
+
 /*
  * Non-static for lua unit testing
  */
@@ -541,7 +561,7 @@ rspamd_redis_init(struct rspamd_stat_ctx *ctx,
        ucl_object_push_lua(L, st->classifier->cfg->opts, false);
        ucl_object_push_lua(L, st->stcf->opts, false);
        lua_pushstring(L, backend->stcf->symbol);
-       lua_pushboolean(L, backend->stcf->is_spam);
+       lua_pushstring(L, get_class_label(backend->stcf)); /* Pass class label instead of boolean */
 
        /* Push event loop if there is one available (e.g. we are not in rspamadm mode) */
        if (ctx->event_loop) {
@@ -607,10 +627,12 @@ rspamd_redis_runtime(struct rspamd_task *task,
                return nullptr;
        }
 
+       const char *class_label = get_class_label(stcf);
+
        /* Look for the cached results */
        if (!learn) {
                auto maybe_existing = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
-                                                                                                                                                                       object_expanded, stcf->is_spam);
+                                                                                                                                                                       object_expanded, class_label);
 
                if (maybe_existing) {
                        auto *rt = maybe_existing.value();
@@ -624,24 +646,45 @@ rspamd_redis_runtime(struct rspamd_task *task,
        /* No cached result (or learn), create new one */
        auto *rt = new redis_stat_runtime<float>(ctx, task, object_expanded);
 
-       if (!learn) {
+       if (!learn && stcf->clcf && stcf->clcf->class_names && stcf->clcf->class_names->len > 2) {
                /*
-                * For check, we also need to create the opposite class runtime to avoid
-                * double call for Redis scripts.
-                * This runtime will be filled later.
+                * For multi-class classification, we need to create runtimes for ALL classes
+                * to avoid multiple Redis calls. The actual Redis call will fetch data for all classes.
                 */
+               GList *cur = stcf->clcf->statfiles;
+               while (cur) {
+                       auto *other_stcf = (struct rspamd_statfile_config *) cur->data;
+                       if (other_stcf != stcf) {
+                               const char *other_label = get_class_label(other_stcf);
+
+                               auto maybe_other_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
+                                                                                                                                                                                       object_expanded, other_label);
+                               if (!maybe_other_rt) {
+                                       auto *other_rt = new redis_stat_runtime<float>(ctx, task, object_expanded);
+                                       other_rt->save_in_mempool(other_label);
+                                       other_rt->need_redis_call = false;
+                               }
+                       }
+                       cur = g_list_next(cur);
+               }
+       }
+       else if (!learn) {
+               /*
+                * For binary classification, create the opposite class runtime to avoid
+                * double call for Redis scripts (backward compatibility).
+                */
+               const char *opposite_label = stcf->is_spam ? "H" : "S";
                auto maybe_opposite_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
-                                                                                                                                                                          object_expanded,
-                                                                                                                                                                          !stcf->is_spam);
+                                                                                                                                                                          object_expanded, opposite_label);
 
                if (!maybe_opposite_rt) {
                        auto *opposite_rt = new redis_stat_runtime<float>(ctx, task, object_expanded);
-                       opposite_rt->save_in_mempool(!stcf->is_spam);
+                       opposite_rt->save_in_mempool(opposite_label);
                        opposite_rt->need_redis_call = false;
                }
        }
 
-       rt->save_in_mempool(stcf->is_spam);
+       rt->save_in_mempool(class_label);
 
        return rt;
 }
@@ -823,16 +866,10 @@ rspamd_redis_classified(lua_State *L)
        bool result = lua_toboolean(L, 2);
 
        if (result) {
-               /* Indexes:
-                * 3 - learned_ham (int)
-                * 4 - learned_spam (int)
-                * 5 - ham_tokens (pair<int, int>)
-                * 6 - spam_tokens (pair<int, int>)
+               /* Check if this is binary format [learned_ham, learned_spam, ham_tokens, spam_tokens]
+                * or multi-class format [learned_counts_table, outputs_table]
                 */
 
-               /*
-                * We need to fill our runtime AND the opposite runtime
-                */
                auto filler_func = [](redis_stat_runtime<float> *rt, lua_State *L, unsigned learned, int tokens_pos) {
                        rt->learned = learned;
                        redis_stat_runtime<float>::result_type *res;
@@ -854,32 +891,96 @@ rspamd_redis_classified(lua_State *L)
                        rt->set_results(res);
                };
 
-               auto opposite_rt_maybe = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
-                                                                                                                                                                          rt->redis_object_expanded,
-                                                                                                                                                                          !rt->stcf->is_spam);
+               /* Check if result[3] is a number (binary) or table (multi-class) */
+               lua_rawgeti(L, 3, 1); /* Get first element of result array */
+               bool is_binary_format = lua_isnumber(L, -1);
+               lua_pop(L, 1);
 
-               if (!opposite_rt_maybe) {
-                       msg_err_task("internal error: cannot find opposite runtime for cookie %s", cookie);
+               if (is_binary_format) {
+                       /* Binary format: [learned_ham, learned_spam, ham_tokens, spam_tokens] */
 
-                       return 0;
-               }
+                       /* Find the opposite runtime for binary classification compatibility */
+                       const char *opposite_label;
+                       if (rt->stcf->class_name) {
+                               /* Multi-class: find a different class (simplified for now) */
+                               opposite_label = strcmp(get_class_label(rt->stcf), "S") == 0 ? "H" : "S";
+                       }
+                       else {
+                               /* Binary: use opposite spam/ham */
+                               opposite_label = rt->stcf->is_spam ? "H" : "S";
+                       }
+                       auto opposite_rt_maybe = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
+                                                                                                                                                                                  rt->redis_object_expanded,
+                                                                                                                                                                                  opposite_label);
+
+                       if (!opposite_rt_maybe) {
+                               msg_err_task("internal error: cannot find opposite runtime for cookie %s", cookie);
+                               return 0;
+                       }
+
+                       if (rt->stcf->is_spam) {
+                               filler_func(rt, L, lua_tointeger(L, 4), 6);
+                               filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 3), 5);
+                       }
+                       else {
+                               filler_func(rt, L, lua_tointeger(L, 3), 5);
+                               filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 4), 6);
+                       }
 
-               if (rt->stcf->is_spam) {
-                       filler_func(rt, L, lua_tointeger(L, 4), 6);
-                       filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 3), 5);
+                       /* Process all tokens */
+                       g_assert(rt->tokens != nullptr);
+                       rt->process_tokens(rt->tokens);
+                       opposite_rt_maybe.value()->process_tokens(rt->tokens);
                }
                else {
-                       filler_func(rt, L, lua_tointeger(L, 3), 5);
-                       filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 4), 6);
-               }
+                       /* Multi-class format: [learned_counts_table, outputs_table] */
+
+                       /* Get learned counts table (index 3) and outputs table (index 4) */
+                       lua_rawgeti(L, 3, 1); /* learned_counts */
+                       lua_rawgeti(L, 3, 2); /* outputs */
+
+                       /* Iterate through all class labels to fill all runtimes */
+                       if (rt->stcf->clcf && rt->stcf->clcf->class_labels) {
+                               GHashTableIter iter;
+                               gpointer key, value;
+                               g_hash_table_iter_init(&iter, rt->stcf->clcf->class_labels);
+
+                               while (g_hash_table_iter_next(&iter, &key, &value)) {
+                                       const char *class_label = (const char *) value;
+
+                                       /* Find runtime for this class */
+                                       auto class_rt_maybe = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
+                                                                                                                                                                                               rt->redis_object_expanded,
+                                                                                                                                                                                               class_label);
+
+                                       if (class_rt_maybe) {
+                                               auto *class_rt = class_rt_maybe.value();
+
+                                               /* Get learned count for this class */
+                                               lua_pushstring(L, class_label);
+                                               lua_gettable(L, -3); /* learned_counts[class_label] */
+                                               unsigned learned = lua_tointeger(L, -1);
+                                               lua_pop(L, 1);
+
+                                               /* Get outputs for this class */
+                                               lua_pushstring(L, class_label);
+                                               lua_gettable(L, -2); /* outputs[class_label] */
+                                               int outputs_pos = lua_gettop(L);
+
+                                               filler_func(class_rt, L, learned, outputs_pos);
+                                               lua_pop(L, 1);
+                                       }
+                               }
+                       }
+
+                       lua_pop(L, 2); /* Pop learned_counts and outputs tables */
 
-               /* Mark task as being processed */
-               task->flags |= RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS | RSPAMD_TASK_FLAG_HAS_HAM_TOKENS;
+                       /* Process tokens for all runtimes */
+                       g_assert(rt->tokens != nullptr);
+                       rt->process_tokens(rt->tokens);
+               }
 
-               /* Process all tokens */
-               g_assert(rt->tokens != nullptr);
-               rt->process_tokens(rt->tokens);
-               opposite_rt_maybe.value()->process_tokens(rt->tokens);
+               /* Tokens processed - no need to set flags in multi-class approach */
        }
        else {
                /* Error message is on index 3 */
@@ -929,7 +1030,25 @@ rspamd_redis_process_tokens(struct rspamd_task *task,
        rspamd_lua_task_push(L, task);
        lua_pushstring(L, rt->redis_object_expanded);
        lua_pushinteger(L, id);
-       lua_pushboolean(L, rt->stcf->is_spam);
+
+       /* Send all class labels for multi-class support */
+       if (rt->stcf->clcf && rt->stcf->clcf->class_labels && g_hash_table_size(rt->stcf->clcf->class_labels) > 0) {
+               /* Multi-class: send array of all class labels */
+               lua_newtable(L);
+               GHashTableIter iter;
+               gpointer key, value;
+               int idx = 1;
+               g_hash_table_iter_init(&iter, rt->stcf->clcf->class_labels);
+               while (g_hash_table_iter_next(&iter, &key, &value)) {
+                       lua_pushstring(L, (const char *) value); /* Use the label, not class name */
+                       lua_rawseti(L, -2, idx++);
+               }
+       }
+       else {
+               /* Binary compatibility: send current class label as single string */
+               lua_pushstring(L, get_class_label(rt->stcf));
+       }
+
        lua_new_text(L, tokens_buf, tokens_len, false);
 
        /* Store rt in random cookie */
@@ -1028,7 +1147,7 @@ rspamd_redis_learn_tokens(struct rspamd_task *task,
        rspamd_lua_task_push(L, task);
        lua_pushstring(L, rt->redis_object_expanded);
        lua_pushinteger(L, id);
-       lua_pushboolean(L, rt->stcf->is_spam);
+       lua_pushstring(L, get_class_label(rt->stcf)); /* Pass class label instead of boolean */
        lua_pushstring(L, rt->stcf->symbol);
 
        /* Detect unlearn */
index 973dc30a7609bb25a812d4aa98ccbca852d49128..8f29a3b4ed19158355d8df7272589c3d53442566 100644 (file)
@@ -589,12 +589,7 @@ rspamd_sqlite3_process_tokens(struct rspamd_task *task,
                        }
                }
 
-               if (rt->cf->is_spam) {
-                       task->flags |= RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS;
-               }
-               else {
-                       task->flags |= RSPAMD_TASK_FLAG_HAS_HAM_TOKENS;
-               }
+               /* No longer need to set flags - multi-class handles missing data naturally */
        }
 
 
index 93b5149dad05d6ece9c2558ba0c58a1efcda6e1c..4a1b0cf32add147bc81b47d51588ce6c5b2d8610 100644 (file)
@@ -94,8 +94,8 @@ inv_chi_square(struct rspamd_task *task, double value, int freedom_deg)
 }
 
 struct bayes_task_closure {
-       double ham_prob;
-       double spam_prob;
+       double ham_prob;  /* Kept for binary compatibility */
+       double spam_prob; /* Kept for binary compatibility */
        double meta_skip_prob;
        uint64_t processed_tokens;
        uint64_t total_hits;
@@ -103,6 +103,20 @@ struct bayes_task_closure {
        struct rspamd_task *task;
 };
 
+/* Multi-class classification closure */
+struct bayes_multiclass_closure {
+       double *class_log_probs;  /* Array of log probabilities for each class */
+       uint64_t *class_learns;   /* Learning counts for each class */
+       char **class_names;       /* Array of class names */
+       unsigned int num_classes; /* Number of classes */
+       double meta_skip_prob;
+       uint64_t processed_tokens;
+       uint64_t total_hits;
+       uint64_t text_tokens;
+       struct rspamd_task *task;
+       struct rspamd_classifier_config *cfg;
+};
+
 /*
  * Mathematically we use pow(complexity, complexity), where complexity is the
  * window index
@@ -248,6 +262,301 @@ bayes_classify_token(struct rspamd_classifier *ctx,
        }
 }
 
+/*
+ * Multinomial token classification for multi-class Bayes
+ */
+static void
+bayes_classify_token_multiclass(struct rspamd_classifier *ctx,
+                                                               rspamd_token_t *tok,
+                                                               struct bayes_multiclass_closure *cl)
+{
+       unsigned int i, j;
+       int id;
+       struct rspamd_statfile *st;
+       struct rspamd_task *task;
+       const char *token_type = "txt";
+       double val, fw, w;
+       unsigned int *class_counts;
+       unsigned int total_count = 0;
+
+       task = cl->task;
+
+       /* Skip meta tokens probabilistically if configured */
+       if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_META && cl->meta_skip_prob > 0) {
+               val = rspamd_random_double_fast();
+               if (val <= cl->meta_skip_prob) {
+                       return;
+               }
+               token_type = "meta";
+       }
+
+       /* Allocate array for class counts */
+       class_counts = g_alloca(cl->num_classes * sizeof(unsigned int));
+       memset(class_counts, 0, cl->num_classes * sizeof(unsigned int));
+
+       /* Collect counts for each class */
+       for (i = 0; i < ctx->statfiles_ids->len; i++) {
+               id = g_array_index(ctx->statfiles_ids, int, i);
+               st = g_ptr_array_index(ctx->ctx->statfiles, id);
+               g_assert(st != NULL);
+               val = tok->values[id];
+
+               if (val > 0) {
+                       /* Find which class this statfile belongs to */
+                       for (j = 0; j < cl->num_classes; j++) {
+                               if (st->stcf->class_name &&
+                                       strcmp(st->stcf->class_name, cl->class_names[j]) == 0) {
+                                       class_counts[j] += val;
+                                       total_count += val;
+                                       cl->total_hits += val;
+                                       break;
+                               }
+                       }
+               }
+       }
+
+       /* Calculate multinomial probability for this token */
+       if (total_count >= ctx->cfg->min_token_hits) {
+               /* Feature weight calculation */
+               if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_UNIGRAM) {
+                       fw = 1.0;
+               }
+               else {
+                       fw = feature_weight[tok->window_idx % G_N_ELEMENTS(feature_weight)];
+               }
+
+               w = (fw * total_count) / (1.0 + fw * total_count);
+
+               /* Apply multinomial model for each class */
+               for (j = 0; j < cl->num_classes; j++) {
+                       double class_freq = (double) class_counts[j] / MAX(1.0, (double) cl->class_learns[j]);
+                       double class_prob = PROB_COMBINE(class_freq, total_count, w, 1.0 / cl->num_classes);
+
+                       /* Skip probabilities too close to uniform (1/num_classes) */
+                       double uniform_prior = 1.0 / cl->num_classes;
+                       if (fabs(class_prob - uniform_prior) < ctx->cfg->min_prob_strength) {
+                               continue;
+                       }
+
+                       cl->class_log_probs[j] += log(class_prob);
+               }
+
+               cl->processed_tokens++;
+               if (!(tok->flags & RSPAMD_STAT_TOKEN_FLAG_META)) {
+                       cl->text_tokens++;
+               }
+
+               if (tok->t1 && tok->t2) {
+                       msg_debug_bayes("token(%s) %uL <%*s:%*s>: weight: %.3f, total_count: %ud, "
+                                                       "processed for %u classes",
+                                                       token_type, tok->data,
+                                                       (int) tok->t1->stemmed.len, tok->t1->stemmed.begin,
+                                                       (int) tok->t2->stemmed.len, tok->t2->stemmed.begin,
+                                                       fw, total_count, cl->num_classes);
+               }
+       }
+}
+
+/*
+ * Multinomial Bayes classification with Fisher confidence
+ */
+static gboolean
+bayes_classify_multiclass(struct rspamd_classifier *ctx,
+                                                 GPtrArray *tokens,
+                                                 struct rspamd_task *task)
+{
+       struct bayes_multiclass_closure cl;
+       rspamd_token_t *tok;
+       unsigned int i, j, text_tokens = 0;
+       int id;
+       struct rspamd_statfile *st;
+       rspamd_multiclass_result_t *result;
+       double *normalized_probs;
+       double max_log_prob = -INFINITY;
+       unsigned int winning_class_idx = 0;
+       double confidence;
+
+       g_assert(ctx != NULL);
+       g_assert(tokens != NULL);
+
+       /* Initialize multi-class closure */
+       memset(&cl, 0, sizeof(cl));
+       cl.task = task;
+       cl.cfg = ctx->cfg;
+
+       /* Get class information from classifier config */
+       if (!ctx->cfg->class_names || ctx->cfg->class_names->len < 2) {
+               msg_debug_bayes("insufficient classes for multiclass classification");
+               return TRUE; /* Fall back to binary mode */
+       }
+
+       cl.num_classes = ctx->cfg->class_names->len;
+       cl.class_names = (char **) ctx->cfg->class_names->pdata;
+       cl.class_log_probs = g_alloca(cl.num_classes * sizeof(double));
+       cl.class_learns = g_alloca(cl.num_classes * sizeof(uint64_t));
+
+       /* Initialize probabilities and get learning counts */
+       for (i = 0; i < cl.num_classes; i++) {
+               cl.class_log_probs[i] = 0.0;
+               cl.class_learns[i] = 0;
+       }
+
+       /* Collect learning counts for each class */
+       for (i = 0; i < ctx->statfiles_ids->len; i++) {
+               id = g_array_index(ctx->statfiles_ids, int, i);
+               st = g_ptr_array_index(ctx->ctx->statfiles, id);
+               g_assert(st != NULL);
+
+               for (j = 0; j < cl.num_classes; j++) {
+                       if (st->stcf->class_name &&
+                               strcmp(st->stcf->class_name, cl.class_names[j]) == 0) {
+                               cl.class_learns[j] += st->backend->total_learns(task,
+                                                                                                                               g_ptr_array_index(task->stat_runtimes, id), ctx->ctx);
+                               break;
+                       }
+               }
+       }
+
+       /* Check minimum learns requirement */
+       if (ctx->cfg->min_learns > 0) {
+               for (i = 0; i < cl.num_classes; i++) {
+                       if (cl.class_learns[i] < ctx->cfg->min_learns) {
+                               msg_info_task("not classified as %s. The class needs more "
+                                                         "training samples. Currently: %ul; minimum %ud required",
+                                                         cl.class_names[i], cl.class_learns[i], ctx->cfg->min_learns);
+                               return TRUE;
+                       }
+               }
+       }
+
+       /* Count text tokens */
+       for (i = 0; i < tokens->len; i++) {
+               tok = g_ptr_array_index(tokens, i);
+               if (!(tok->flags & RSPAMD_STAT_TOKEN_FLAG_META)) {
+                       text_tokens++;
+               }
+       }
+
+       if (text_tokens == 0) {
+               msg_info_task("skipped classification as there are no text tokens. "
+                                         "Total tokens: %ud",
+                                         tokens->len);
+               return TRUE;
+       }
+
+       /* Set meta token skip probability */
+       if (text_tokens > tokens->len - text_tokens) {
+               cl.meta_skip_prob = 0.0;
+       }
+       else {
+               cl.meta_skip_prob = 1.0 - (double) text_tokens / tokens->len;
+       }
+
+       /* Process all tokens */
+       for (i = 0; i < tokens->len; i++) {
+               tok = g_ptr_array_index(tokens, i);
+               bayes_classify_token_multiclass(ctx, tok, &cl);
+       }
+
+       if (cl.processed_tokens == 0) {
+               msg_info_bayes("no tokens found in bayes database "
+                                          "(%ud total tokens, %ud text tokens), ignore stats",
+                                          tokens->len, text_tokens);
+               return TRUE;
+       }
+
+       if (ctx->cfg->min_tokens > 0 &&
+               cl.text_tokens < (int) (ctx->cfg->min_tokens * 0.1)) {
+               msg_info_bayes("ignore bayes probability since we have "
+                                          "found too few text tokens: %uL (of %ud checked), "
+                                          "at least %d required",
+                                          cl.text_tokens, text_tokens,
+                                          (int) (ctx->cfg->min_tokens * 0.1));
+               return TRUE;
+       }
+
+       /* Normalize probabilities using softmax */
+       normalized_probs = g_alloca(cl.num_classes * sizeof(double));
+
+       /* Find maximum for numerical stability */
+       for (i = 0; i < cl.num_classes; i++) {
+               if (cl.class_log_probs[i] > max_log_prob) {
+                       max_log_prob = cl.class_log_probs[i];
+                       winning_class_idx = i;
+               }
+       }
+
+       /* Apply softmax normalization */
+       double sum_exp = 0.0;
+       for (i = 0; i < cl.num_classes; i++) {
+               normalized_probs[i] = exp(cl.class_log_probs[i] - max_log_prob);
+               sum_exp += normalized_probs[i];
+       }
+
+       if (sum_exp > 0) {
+               for (i = 0; i < cl.num_classes; i++) {
+                       normalized_probs[i] /= sum_exp;
+               }
+       }
+       else {
+               /* Fallback to uniform distribution */
+               for (i = 0; i < cl.num_classes; i++) {
+                       normalized_probs[i] = 1.0 / cl.num_classes;
+               }
+       }
+
+       /* Calculate confidence using Fisher method for the winning class */
+       if (max_log_prob > -300) {
+               confidence = 1.0 - inv_chi_square(task, max_log_prob, cl.processed_tokens);
+       }
+       else {
+               confidence = normalized_probs[winning_class_idx];
+       }
+
+       /* Create and store multiclass result */
+       result = g_new0(rspamd_multiclass_result_t, 1);
+       result->class_names = g_new(char *, cl.num_classes);
+       result->probabilities = g_new(double, cl.num_classes);
+       result->num_classes = cl.num_classes;
+       result->winning_class = cl.class_names[winning_class_idx]; /* Reference, not copy */
+       result->confidence = confidence;
+
+       for (i = 0; i < cl.num_classes; i++) {
+               result->class_names[i] = g_strdup(cl.class_names[i]);
+               result->probabilities[i] = normalized_probs[i];
+       }
+
+       rspamd_task_set_multiclass_result(task, result);
+
+       /* Insert symbol for winning class if confidence is significant */
+       if (confidence > 0.05) {
+               char sumbuf[32];
+               double final_prob = rspamd_normalize_probability(confidence, 0.5);
+
+               rspamd_snprintf(sumbuf, sizeof(sumbuf), "%.2f%%", confidence * 100.0);
+
+               /* Find the statfile for the winning class to get the symbol */
+               for (i = 0; i < ctx->statfiles_ids->len; i++) {
+                       id = g_array_index(ctx->statfiles_ids, int, i);
+                       st = g_ptr_array_index(ctx->ctx->statfiles, id);
+
+                       if (st->stcf->class_name &&
+                               strcmp(st->stcf->class_name, cl.class_names[winning_class_idx]) == 0) {
+                               rspamd_task_insert_result(task, st->stcf->symbol, final_prob, sumbuf);
+                               break;
+                       }
+               }
+
+               msg_debug_bayes("multiclass classification: winning class '%s' with "
+                                               "probability %.3f, confidence %.3f, %uL tokens processed",
+                                               cl.class_names[winning_class_idx],
+                                               normalized_probs[winning_class_idx],
+                                               confidence, cl.processed_tokens);
+       }
+
+       return TRUE;
+}
+
 
 gboolean
 bayes_init(struct rspamd_config *cfg,
@@ -279,6 +588,28 @@ bayes_classify(struct rspamd_classifier *ctx,
        g_assert(ctx != NULL);
        g_assert(tokens != NULL);
 
+       /* Check if this is a multi-class classifier */
+       if (ctx->cfg->class_names && ctx->cfg->class_names->len >= 2) {
+               /* Verify that at least one statfile has class_name set (indicating new multi-class config) */
+               gboolean has_class_names = FALSE;
+               for (i = 0; i < ctx->statfiles_ids->len; i++) {
+                       int id = g_array_index(ctx->statfiles_ids, int, i);
+                       struct rspamd_statfile *st = g_ptr_array_index(ctx->ctx->statfiles, id);
+                       if (st->stcf->class_name) {
+                               has_class_names = TRUE;
+                               break;
+                       }
+               }
+
+               if (has_class_names) {
+                       msg_debug_bayes("using multiclass classification with %u classes",
+                                                       (unsigned int) ctx->cfg->class_names->len);
+                       return bayes_classify_multiclass(ctx, tokens, task);
+               }
+       }
+
+       /* Fall back to binary classification */
+       msg_debug_bayes("using binary classification");
        memset(&cl, 0, sizeof(cl));
        cl.task = task;
 
@@ -549,3 +880,152 @@ bayes_learn_spam(struct rspamd_classifier *ctx,
 
        return TRUE;
 }
+
+gboolean
+bayes_learn_class(struct rspamd_classifier *ctx,
+                                 GPtrArray *tokens,
+                                 struct rspamd_task *task,
+                                 const char *class_name,
+                                 gboolean unlearn,
+                                 GError **err)
+{
+       unsigned int i, j, total_cnt;
+       int id;
+       struct rspamd_statfile *st;
+       rspamd_token_t *tok;
+       gboolean incrementing;
+       unsigned int *class_counts = NULL;
+       struct rspamd_statfile **class_statfiles = NULL;
+       unsigned int num_classes = 0;
+
+       g_assert(ctx != NULL);
+       g_assert(tokens != NULL);
+       g_assert(class_name != NULL);
+
+       incrementing = ctx->cfg->flags & RSPAMD_FLAG_CLASSIFIER_INCREMENTING_BACKEND;
+
+       /* Count classes and prepare arrays for multi-class learning */
+       if (ctx->cfg->class_names && ctx->cfg->class_names->len > 0) {
+               num_classes = ctx->cfg->class_names->len;
+               class_counts = g_alloca(num_classes * sizeof(unsigned int));
+               class_statfiles = g_alloca(num_classes * sizeof(struct rspamd_statfile *));
+               memset(class_counts, 0, num_classes * sizeof(unsigned int));
+               memset(class_statfiles, 0, num_classes * sizeof(struct rspamd_statfile *));
+       }
+
+       for (i = 0; i < tokens->len; i++) {
+               total_cnt = 0;
+               tok = g_ptr_array_index(tokens, i);
+
+               /* Reset class counts for this token */
+               if (num_classes > 0) {
+                       memset(class_counts, 0, num_classes * sizeof(unsigned int));
+               }
+
+               for (j = 0; j < ctx->statfiles_ids->len; j++) {
+                       id = g_array_index(ctx->statfiles_ids, int, j);
+                       st = g_ptr_array_index(ctx->ctx->statfiles, id);
+                       g_assert(st != NULL);
+
+                       /* Determine if this statfile matches our target class */
+                       gboolean is_target_class = FALSE;
+                       if (st->stcf->class_name) {
+                               /* Multi-class: exact class name match */
+                               is_target_class = (strcmp(st->stcf->class_name, class_name) == 0);
+                       }
+                       else {
+                               /* Legacy binary: map class_name to spam/ham */
+                               if (strcmp(class_name, "spam") == 0 || strcmp(class_name, "S") == 0) {
+                                       is_target_class = st->stcf->is_spam;
+                               }
+                               else if (strcmp(class_name, "ham") == 0 || strcmp(class_name, "H") == 0) {
+                                       is_target_class = !st->stcf->is_spam;
+                               }
+                       }
+
+                       if (is_target_class) {
+                               /* Learning: increment the target class */
+                               if (incrementing) {
+                                       tok->values[id] = 1;
+                               }
+                               else {
+                                       tok->values[id]++;
+                               }
+                               total_cnt += tok->values[id];
+
+                               /* Track class counts for debugging */
+                               if (num_classes > 0) {
+                                       for (unsigned int k = 0; k < num_classes; k++) {
+                                               const char *check_class = (const char *) g_ptr_array_index(ctx->cfg->class_names, k);
+                                               if (st->stcf->class_name && strcmp(st->stcf->class_name, check_class) == 0) {
+                                                       class_counts[k] += tok->values[id];
+                                                       class_statfiles[k] = st;
+                                                       break;
+                                               }
+                                       }
+                               }
+                       }
+                       else {
+                               /* Unlearning: decrement other classes if unlearn flag is set */
+                               if (tok->values[id] > 0 && unlearn) {
+                                       if (incrementing) {
+                                               tok->values[id] = -1;
+                                       }
+                                       else {
+                                               tok->values[id]--;
+                                       }
+                                       total_cnt += tok->values[id];
+
+                                       /* Track class counts for debugging */
+                                       if (num_classes > 0) {
+                                               for (unsigned int k = 0; k < num_classes; k++) {
+                                                       const char *check_class = (const char *) g_ptr_array_index(ctx->cfg->class_names, k);
+                                                       if (st->stcf->class_name && strcmp(st->stcf->class_name, check_class) == 0) {
+                                                               class_counts[k] += tok->values[id];
+                                                               class_statfiles[k] = st;
+                                                               break;
+                                                       }
+                                               }
+                                       }
+                               }
+                               else if (incrementing) {
+                                       tok->values[id] = 0;
+                               }
+                       }
+               }
+
+               /* Debug logging */
+               if (tok->t1 && tok->t2) {
+                       if (num_classes > 0) {
+                               GString *debug_str = g_string_new("");
+                               for (unsigned int k = 0; k < num_classes; k++) {
+                                       const char *check_class = (const char *) g_ptr_array_index(ctx->cfg->class_names, k);
+                                       g_string_append_printf(debug_str, "%s:%d ", check_class, class_counts[k]);
+                               }
+                               msg_debug_bayes("token %uL <%*s:%*s>: window: %d, total_count: %d, "
+                                                               "class_counts: %s",
+                                                               tok->data,
+                                                               (int) tok->t1->stemmed.len, tok->t1->stemmed.begin,
+                                                               (int) tok->t2->stemmed.len, tok->t2->stemmed.begin,
+                                                               tok->window_idx, total_cnt, debug_str->str);
+                               g_string_free(debug_str, TRUE);
+                       }
+                       else {
+                               msg_debug_bayes("token %uL <%*s:%*s>: window: %d, total_count: %d, "
+                                                               "class: %s",
+                                                               tok->data,
+                                                               (int) tok->t1->stemmed.len, tok->t1->stemmed.begin,
+                                                               (int) tok->t2->stemmed.len, tok->t2->stemmed.begin,
+                                                               tok->window_idx, total_cnt, class_name);
+                       }
+               }
+               else {
+                       msg_debug_bayes("token %uL <?:?>: window: %d, total_count: %d, "
+                                                       "class: %s",
+                                                       tok->data,
+                                                       tok->window_idx, total_cnt, class_name);
+               }
+       }
+
+       return TRUE;
+}
index 22978e67347a7d6f5ed9ef91888ad73eca276c52..cab658146e82d629abd55670fbfc5bc5bcf7d064 100644 (file)
@@ -54,6 +54,13 @@ struct rspamd_stat_classifier {
                                                                gboolean unlearn,
                                                                GError **err);
 
+       gboolean (*learn_class_func)(struct rspamd_classifier *ctx,
+                                                                GPtrArray *input,
+                                                                struct rspamd_task *task,
+                                                                const char *class_name,
+                                                                gboolean unlearn,
+                                                                GError **err);
+
        void (*fin_func)(struct rspamd_classifier *cl);
 };
 
@@ -73,6 +80,13 @@ gboolean bayes_learn_spam(struct rspamd_classifier *ctx,
                                                  gboolean unlearn,
                                                  GError **err);
 
+gboolean bayes_learn_class(struct rspamd_classifier *ctx,
+                                                  GPtrArray *tokens,
+                                                  struct rspamd_task *task,
+                                                  const char *class_name,
+                                                  gboolean unlearn,
+                                                  GError **err);
+
 void bayes_fin(struct rspamd_classifier *);
 
 /* Generic lua classifier */
index 811566ad382581a6d4059392612b2054206645ba..aa6111a8b2dfccf612f1883b25205292994a58f1 100644 (file)
@@ -107,6 +107,23 @@ rspamd_stat_result_t rspamd_stat_learn(struct rspamd_task *task,
                                                                           unsigned int stage,
                                                                           GError **err);
 
+/**
+ * Learn task as a specific class, task must be processed prior to this call
+ * @param task task to learn
+ * @param class_name name of the class to learn (e.g., "spam", "ham", "transactional")
+ * @param L lua state
+ * @param classifier NULL to learn all classifiers, name to learn a specific one
+ * @param stage learning stage
+ * @param err error returned
+ * @return TRUE if task has been learned
+ */
+rspamd_stat_result_t rspamd_stat_learn_class(struct rspamd_task *task,
+                                                                                        const char *class_name,
+                                                                                        lua_State *L,
+                                                                                        const char *classifier,
+                                                                                        unsigned int stage,
+                                                                                        GError **err);
+
 /**
  * Get the overall statistics for all statfile backends
  * @param cfg configuration
@@ -120,6 +137,43 @@ rspamd_stat_result_t rspamd_stat_statistics(struct rspamd_task *task,
 
 void rspamd_stat_unload(void);
 
+/**
+ * Multi-class classification result structure
+ */
+typedef struct {
+       char **class_names;        /**< Array of class names */
+       double *probabilities;     /**< Array of probabilities for each class */
+       unsigned int num_classes;  /**< Number of classes */
+       const char *winning_class; /**< Name of the winning class (reference, not owned) */
+       double confidence;         /**< Confidence of the winning class */
+} rspamd_multiclass_result_t;
+
+/**
+ * Set multi-class classification result for a task
+ */
+void rspamd_task_set_multiclass_result(struct rspamd_task *task,
+                                                                          rspamd_multiclass_result_t *result);
+
+/**
+ * Get multi-class classification result from a task
+ */
+rspamd_multiclass_result_t *rspamd_task_get_multiclass_result(struct rspamd_task *task);
+
+/**
+ * Free multi-class result structure
+ */
+void rspamd_multiclass_result_free(rspamd_multiclass_result_t *result);
+
+/**
+ * Set autolearn class for a task
+ */
+void rspamd_task_set_autolearn_class(struct rspamd_task *task, const char *class_name);
+
+/**
+ * Get autolearn class from a task
+ */
+const char *rspamd_task_get_autolearn_class(struct rspamd_task *task);
+
 #ifdef __cplusplus
 }
 #endif
index 8a5313df22db1d21509f581dfe606daac1e910f1..5ada7d468193e2b10b7db189742ce9969a058a63 100644 (file)
@@ -28,6 +28,7 @@ static struct rspamd_stat_classifier lua_classifier = {
        .init_func = lua_classifier_init,
        .classify_func = lua_classifier_classify,
        .learn_spam_func = lua_classifier_learn_spam,
+       .learn_class_func = NULL, /* TODO: implement lua multi-class learning */
        .fin_func = NULL,
 };
 
@@ -37,6 +38,7 @@ static struct rspamd_stat_classifier stat_classifiers[] = {
                .init_func = bayes_init,
                .classify_func = bayes_classify,
                .learn_spam_func = bayes_learn_spam,
+               .learn_class_func = bayes_learn_class,
                .fin_func = bayes_fin,
        }};
 
@@ -68,8 +70,7 @@ static struct rspamd_stat_tokenizer stat_tokenizers[] = {
                .dec_learns = rspamd_##eltn##_dec_learns,                       \
                .get_stat = rspamd_##eltn##_get_stat,                           \
                .load_tokenizer_config = rspamd_##eltn##_load_tokenizer_config, \
-               .close = rspamd_##eltn##_close                                  \
-       }
+               .close = rspamd_##eltn##_close}
 #define RSPAMD_STAT_BACKEND_ELT_READONLY(nam, eltn)                     \
        {                                                                   \
                .name = #nam,                                                   \
@@ -85,8 +86,7 @@ static struct rspamd_stat_tokenizer stat_tokenizers[] = {
                .dec_learns = NULL,                                             \
                .get_stat = rspamd_##eltn##_get_stat,                           \
                .load_tokenizer_config = rspamd_##eltn##_load_tokenizer_config, \
-               .close = rspamd_##eltn##_close                                  \
-       }
+               .close = rspamd_##eltn##_close}
 
 static struct rspamd_stat_backend stat_backends[] = {
        RSPAMD_STAT_BACKEND_ELT(mmap, mmaped_file),
@@ -101,8 +101,7 @@ static struct rspamd_stat_backend stat_backends[] = {
                .runtime = rspamd_stat_cache_##eltn##_runtime, \
                .check = rspamd_stat_cache_##eltn##_check,     \
                .learn = rspamd_stat_cache_##eltn##_learn,     \
-               .close = rspamd_stat_cache_##eltn##_close      \
-       }
+               .close = rspamd_stat_cache_##eltn##_close}
 
 static struct rspamd_stat_cache stat_caches[] = {
        RSPAMD_STAT_CACHE_ELT(sqlite3, sqlite3),
index 176064087bd2e047e74f830b937da200dfb33d9b..5126fd2cc3bea7219131fd4bd4e5f5606035b02a 100644 (file)
 
 static const double similarity_threshold = 80.0;
 
+void rspamd_task_set_multiclass_result(struct rspamd_task *task, rspamd_multiclass_result_t *result)
+{
+       g_assert(task != NULL);
+       g_assert(result != NULL);
+
+       rspamd_mempool_set_variable(task->task_pool, "multiclass_bayes_result", result,
+                                                               (rspamd_mempool_destruct_t) rspamd_multiclass_result_free);
+}
+
+rspamd_multiclass_result_t *
+rspamd_task_get_multiclass_result(struct rspamd_task *task)
+{
+       g_assert(task != NULL);
+
+       return (rspamd_multiclass_result_t *) rspamd_mempool_get_variable(task->task_pool,
+                                                                                                                                         "multiclass_bayes_result");
+}
+
+void rspamd_multiclass_result_free(rspamd_multiclass_result_t *result)
+{
+       if (result == NULL) {
+               return;
+       }
+
+       g_free(result->class_names);
+       g_free(result->probabilities);
+       /* winning_class is a reference, not owned - don't free */
+       g_free(result);
+}
+
+void rspamd_task_set_autolearn_class(struct rspamd_task *task, const char *class_name)
+{
+       g_assert(task != NULL);
+       g_assert(class_name != NULL);
+
+       /* Store the class name in the mempool */
+       const char *class_name_copy = rspamd_mempool_strdup(task->task_pool, class_name);
+       rspamd_mempool_set_variable(task->task_pool, "autolearn_class",
+                                                               (gpointer) class_name_copy, NULL);
+
+       /* Set the appropriate flags */
+       task->flags |= RSPAMD_TASK_FLAG_LEARN_CLASS;
+
+       /* For backward compatibility, also set binary flags */
+       if (strcmp(class_name, "spam") == 0 || strcmp(class_name, "S") == 0) {
+               task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
+       }
+       else if (strcmp(class_name, "ham") == 0 || strcmp(class_name, "H") == 0) {
+               task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
+       }
+}
+
+const char *
+rspamd_task_get_autolearn_class(struct rspamd_task *task)
+{
+       g_assert(task != NULL);
+
+       if (task->flags & RSPAMD_TASK_FLAG_LEARN_CLASS) {
+               return (const char *) rspamd_mempool_get_variable(task->task_pool, "autolearn_class");
+       }
+
+       /* Fallback to binary flags for backward compatibility */
+       if (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) {
+               return "spam";
+       }
+       else if (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM) {
+               return "ham";
+       }
+
+       return NULL;
+}
+
 static void
 rspamd_stat_tokenize_parts_metadata(struct rspamd_stat_ctx *st_ctx,
                                                                        struct rspamd_task *task)
@@ -394,18 +466,9 @@ rspamd_stat_classifiers_process(struct rspamd_stat_ctx *st_ctx,
        }
 
        /*
-        * Do not classify a message if some class is missing
+        * Multi-class approach: don't check for missing classes
+        * Missing tokens naturally result in 0 probability
         */
-       if (!(task->flags & RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS)) {
-               msg_info_task("skip statistics as SPAM class is missing");
-
-               return;
-       }
-       if (!(task->flags & RSPAMD_TASK_FLAG_HAS_HAM_TOKENS)) {
-               msg_info_task("skip statistics as HAM class is missing");
-
-               return;
-       }
 
        for (i = 0; i < st_ctx->classifiers->len; i++) {
                cl = g_ptr_array_index(st_ctx->classifiers, i);
@@ -565,7 +628,24 @@ rspamd_stat_cache_check(struct rspamd_stat_ctx *st_ctx,
 
                if (sel->cache && sel->cachecf) {
                        rt = cl->cache->runtime(task, sel->cachecf, FALSE);
-                       learn_res = cl->cache->check(task, spam, rt);
+
+                       /* For multi-class learning, determine spam boolean from class name if available */
+                       gboolean cache_spam = spam; /* Default to original spam parameter */
+                       const char *autolearn_class = rspamd_task_get_autolearn_class(task);
+                       if (autolearn_class) {
+                               if (strcmp(autolearn_class, "spam") == 0 || strcmp(autolearn_class, "S") == 0) {
+                                       cache_spam = TRUE;
+                               }
+                               else if (strcmp(autolearn_class, "ham") == 0 || strcmp(autolearn_class, "H") == 0) {
+                                       cache_spam = FALSE;
+                               }
+                               else {
+                                       /* For other classes, use a heuristic or default to spam for cache purposes */
+                                       cache_spam = TRUE; /* Non-ham classes are treated as spam for cache */
+                               }
+                       }
+
+                       learn_res = cl->cache->check(task, cache_spam, rt);
                }
 
                if (learn_res == RSPAMD_LEARN_IGNORE) {
@@ -658,9 +738,63 @@ rspamd_stat_classifiers_learn(struct rspamd_stat_ctx *st_ctx,
                        continue;
                }
 
-               if (cl->subrs->learn_spam_func(cl, task->tokens, task, spam,
-                                                                          task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) {
-                       learned = TRUE;
+               /* Check if classifier supports multi-class learning and if we should use it */
+               if (cl->subrs->learn_class_func && cl->cfg->class_names && cl->cfg->class_names->len > 2) {
+                       /* Multi-class learning: determine class name from task flags or autolearn result */
+                       const char *class_name = NULL;
+
+                       if (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) {
+                               /* Find spam class name */
+                               for (unsigned int k = 0; k < cl->cfg->class_names->len; k++) {
+                                       const char *check_class = (const char *) g_ptr_array_index(cl->cfg->class_names, k);
+                                       /* Look for statfile with this class that is spam */
+                                       GList *cur = cl->cfg->statfiles;
+                                       while (cur) {
+                                               struct rspamd_statfile_config *stcf = (struct rspamd_statfile_config *) cur->data;
+                                               if (stcf->class_name && strcmp(stcf->class_name, check_class) == 0 && stcf->is_spam) {
+                                                       class_name = check_class;
+                                                       break;
+                                               }
+                                               cur = g_list_next(cur);
+                                       }
+                                       if (class_name) break;
+                               }
+                               if (!class_name) class_name = "spam"; /* fallback */
+                       }
+                       else if (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM) {
+                               /* Find ham class name */
+                               for (unsigned int k = 0; k < cl->cfg->class_names->len; k++) {
+                                       const char *check_class = (const char *) g_ptr_array_index(cl->cfg->class_names, k);
+                                       /* Look for statfile with this class that is ham */
+                                       GList *cur = cl->cfg->statfiles;
+                                       while (cur) {
+                                               struct rspamd_statfile_config *stcf = (struct rspamd_statfile_config *) cur->data;
+                                               if (stcf->class_name && strcmp(stcf->class_name, check_class) == 0 && !stcf->is_spam) {
+                                                       class_name = check_class;
+                                                       break;
+                                               }
+                                               cur = g_list_next(cur);
+                                       }
+                                       if (class_name) break;
+                               }
+                               if (!class_name) class_name = "ham"; /* fallback */
+                       }
+                       else {
+                               /* Fallback to spam/ham based on the spam parameter */
+                               class_name = spam ? "spam" : "ham";
+                       }
+
+                       if (cl->subrs->learn_class_func(cl, task->tokens, task, class_name,
+                                                                                       task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) {
+                               learned = TRUE;
+                       }
+               }
+               else {
+                       /* Binary learning: use existing function */
+                       if (cl->subrs->learn_spam_func(cl, task->tokens, task, spam,
+                                                                                  task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) {
+                               learned = TRUE;
+                       }
                }
        }
 
@@ -870,7 +1004,24 @@ rspamd_stat_backends_post_learn(struct rspamd_stat_ctx *st_ctx,
 
                if (cl->cache) {
                        cache_run = cl->cache->runtime(task, cl->cachecf, TRUE);
-                       cl->cache->learn(task, spam, cache_run);
+
+                       /* For multi-class learning, determine spam boolean from class name if available */
+                       gboolean cache_spam = spam; /* Default to original spam parameter */
+                       const char *autolearn_class = rspamd_task_get_autolearn_class(task);
+                       if (autolearn_class) {
+                               if (strcmp(autolearn_class, "spam") == 0 || strcmp(autolearn_class, "S") == 0) {
+                                       cache_spam = TRUE;
+                               }
+                               else if (strcmp(autolearn_class, "ham") == 0 || strcmp(autolearn_class, "H") == 0) {
+                                       cache_spam = FALSE;
+                               }
+                               else {
+                                       /* For other classes, use a heuristic or default to spam for cache purposes */
+                                       cache_spam = TRUE; /* Non-ham classes are treated as spam for cache */
+                               }
+                       }
+
+                       cl->cache->learn(task, cache_spam, cache_run);
                }
        }
 
@@ -879,6 +1030,218 @@ rspamd_stat_backends_post_learn(struct rspamd_stat_ctx *st_ctx,
        return res;
 }
 
+static gboolean
+rspamd_stat_classifiers_learn_class(struct rspamd_stat_ctx *st_ctx,
+                                                                       struct rspamd_task *task,
+                                                                       const char *classifier,
+                                                                       const char *class_name,
+                                                                       GError **err)
+{
+       struct rspamd_classifier *cl, *sel = NULL;
+       unsigned int i;
+       gboolean learned = FALSE, too_small = FALSE, too_large = FALSE;
+
+       if ((task->flags & RSPAMD_TASK_FLAG_ALREADY_LEARNED) && err != NULL &&
+               *err == NULL) {
+               /* Do not learn twice */
+               g_set_error(err, rspamd_stat_quark(), 208, "<%s> has been already "
+                                                                                                  "learned as %s, ignore it",
+                                       MESSAGE_FIELD(task, message_id),
+                                       class_name);
+
+               return FALSE;
+       }
+
+       /* Check whether we have learned that file */
+       for (i = 0; i < st_ctx->classifiers->len; i++) {
+               cl = g_ptr_array_index(st_ctx->classifiers, i);
+
+               /* Skip other classifiers if they are not needed */
+               if (classifier != NULL && (cl->cfg->name == NULL ||
+                                                                  g_ascii_strcasecmp(classifier, cl->cfg->name) != 0)) {
+                       continue;
+               }
+
+               sel = cl;
+
+               /* Now check max and min tokens */
+               if (cl->cfg->min_tokens > 0 && task->tokens->len < cl->cfg->min_tokens) {
+                       msg_info_task(
+                               "<%s> contains less tokens than required for %s classifier: "
+                               "%ud < %ud",
+                               MESSAGE_FIELD(task, message_id),
+                               cl->cfg->name,
+                               task->tokens->len,
+                               cl->cfg->min_tokens);
+                       too_small = TRUE;
+                       continue;
+               }
+               else if (cl->cfg->max_tokens > 0 && task->tokens->len > cl->cfg->max_tokens) {
+                       msg_info_task(
+                               "<%s> contains more tokens than allowed for %s classifier: "
+                               "%ud > %ud",
+                               MESSAGE_FIELD(task, message_id),
+                               cl->cfg->name,
+                               task->tokens->len,
+                               cl->cfg->max_tokens);
+                       too_large = TRUE;
+                       continue;
+               }
+
+               /* Use the new multi-class learning function if available */
+               if (cl->subrs->learn_class_func) {
+                       if (cl->subrs->learn_class_func(cl, task->tokens, task, class_name,
+                                                                                       task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) {
+                               learned = TRUE;
+                       }
+               }
+               else {
+                       /* Fallback to binary learning with class name mapping */
+                       gboolean is_spam;
+                       if (strcmp(class_name, "spam") == 0 || strcmp(class_name, "S") == 0) {
+                               is_spam = TRUE;
+                       }
+                       else if (strcmp(class_name, "ham") == 0 || strcmp(class_name, "H") == 0) {
+                               is_spam = FALSE;
+                       }
+                       else {
+                               /* For unknown classes with binary classifier, skip */
+                               msg_info_task("skipping class '%s' for binary classifier %s",
+                                                         class_name, cl->cfg->name);
+                               continue;
+                       }
+
+                       if (cl->subrs->learn_spam_func(cl, task->tokens, task, is_spam,
+                                                                                  task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) {
+                               learned = TRUE;
+                       }
+               }
+       }
+
+       if (sel == NULL) {
+               if (classifier) {
+                       g_set_error(err, rspamd_stat_quark(), 404, "cannot find classifier "
+                                                                                                          "with name %s",
+                                               classifier);
+               }
+               else {
+                       g_set_error(err, rspamd_stat_quark(), 404, "no classifiers defined");
+               }
+
+               return FALSE;
+       }
+
+       if (!learned && err && *err == NULL) {
+               if (too_large) {
+                       g_set_error(err, rspamd_stat_quark(), 204,
+                                               "<%s> contains more tokens than allowed for %s classifier: "
+                                               "%d > %d",
+                                               MESSAGE_FIELD(task, message_id),
+                                               sel->cfg->name,
+                                               task->tokens->len,
+                                               sel->cfg->max_tokens);
+               }
+               else if (too_small) {
+                       g_set_error(err, rspamd_stat_quark(), 204,
+                                               "<%s> contains less tokens than required for %s classifier: "
+                                               "%d < %d",
+                                               MESSAGE_FIELD(task, message_id),
+                                               sel->cfg->name,
+                                               task->tokens->len,
+                                               sel->cfg->min_tokens);
+               }
+       }
+
+       return learned;
+}
+
+rspamd_stat_result_t
+rspamd_stat_learn_class(struct rspamd_task *task,
+                                               const char *class_name,
+                                               lua_State *L,
+                                               const char *classifier,
+                                               unsigned int stage,
+                                               GError **err)
+{
+       struct rspamd_stat_ctx *st_ctx;
+       rspamd_stat_result_t ret = RSPAMD_STAT_PROCESS_OK;
+
+       /*
+        * We assume now that a task has been already classified before
+        * coming to learn
+        */
+       g_assert(RSPAMD_TASK_IS_CLASSIFIED(task));
+
+       st_ctx = rspamd_stat_get_ctx();
+       g_assert(st_ctx != NULL);
+
+       msg_debug_bayes("learn class stage %d has been called for class '%s'", stage, class_name);
+
+       if (st_ctx->classifiers->len == 0) {
+               msg_debug_bayes("no classifiers defined");
+               task->processed_stages |= stage;
+               return ret;
+       }
+
+       if (task->message == NULL) {
+               ret = RSPAMD_STAT_PROCESS_ERROR;
+               if (err && *err == NULL) {
+                       g_set_error(err, rspamd_stat_quark(), 500,
+                                               "Trying to learn an empty message");
+               }
+
+               task->processed_stages |= stage;
+               return ret;
+       }
+
+       if (stage == RSPAMD_TASK_STAGE_LEARN_PRE) {
+               /* Process classifiers - determine spam boolean for compatibility */
+               gboolean spam = (strcmp(class_name, "spam") == 0 || strcmp(class_name, "S") == 0);
+               rspamd_stat_preprocess(st_ctx, task, TRUE, spam);
+
+               if (!rspamd_stat_cache_check(st_ctx, task, classifier, spam, err)) {
+                       msg_debug_bayes("cache check failed, skip learning");
+                       return RSPAMD_STAT_PROCESS_ERROR;
+               }
+       }
+       else if (stage == RSPAMD_TASK_STAGE_LEARN) {
+               /* Process classifiers */
+               if (!rspamd_stat_classifiers_learn_class(st_ctx, task, classifier,
+                                                                                                class_name, err)) {
+                       if (err && *err == NULL) {
+                               g_set_error(err, rspamd_stat_quark(), 500,
+                                                       "Unknown statistics error, found when learning classifiers;"
+                                                       " classifier: %s",
+                                                       task->classifier);
+                       }
+                       return RSPAMD_STAT_PROCESS_ERROR;
+               }
+
+               /* Process backends - determine spam boolean for compatibility */
+               gboolean spam = (strcmp(class_name, "spam") == 0 || strcmp(class_name, "S") == 0);
+               if (!rspamd_stat_backends_learn(st_ctx, task, classifier, spam, err)) {
+                       if (err && *err == NULL) {
+                               g_set_error(err, rspamd_stat_quark(), 500,
+                                                       "Unknown statistics error, found when storing data on backend;"
+                                                       " classifier: %s",
+                                                       task->classifier);
+                       }
+                       return RSPAMD_STAT_PROCESS_ERROR;
+               }
+       }
+       else if (stage == RSPAMD_TASK_STAGE_LEARN_POST) {
+               /* Process backends - determine spam boolean for compatibility */
+               gboolean spam = (strcmp(class_name, "spam") == 0 || strcmp(class_name, "S") == 0);
+               if (!rspamd_stat_backends_post_learn(st_ctx, task, classifier, spam, err)) {
+                       return RSPAMD_STAT_PROCESS_ERROR;
+               }
+       }
+
+       task->processed_stages |= stage;
+
+       return ret;
+}
+
 rspamd_stat_result_t
 rspamd_stat_learn(struct rspamd_task *task,
                                  gboolean spam, lua_State *L, const char *classifier, unsigned int stage,
@@ -1039,12 +1402,11 @@ rspamd_stat_check_autolearn(struct rspamd_task *task)
 
                                        if (mres) {
                                                if (mres->score > rspamd_task_get_required_score(task, mres)) {
-                                                       task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
-
+                                                       rspamd_task_set_autolearn_class(task, "spam");
                                                        ret = TRUE;
                                                }
                                                else if (mres->score < 0) {
-                                                       task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
+                                                       rspamd_task_set_autolearn_class(task, "ham");
                                                        ret = TRUE;
                                                }
                                        }
@@ -1076,12 +1438,11 @@ rspamd_stat_check_autolearn(struct rspamd_task *task)
 
                                        if (mres) {
                                                if (mres->score >= spam_score) {
-                                                       task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
-
+                                                       rspamd_task_set_autolearn_class(task, "spam");
                                                        ret = TRUE;
                                                }
                                                else if (mres->score <= ham_score) {
-                                                       task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
+                                                       rspamd_task_set_autolearn_class(task, "ham");
                                                        ret = TRUE;
                                                }
                                        }
@@ -1117,11 +1478,16 @@ rspamd_stat_check_autolearn(struct rspamd_task *task)
                                                        /* We can have immediate results */
                                                        if (lua_ret) {
                                                                if (strcmp(lua_ret, "ham") == 0) {
-                                                                       task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
+                                                                       rspamd_task_set_autolearn_class(task, "ham");
                                                                        ret = TRUE;
                                                                }
                                                                else if (strcmp(lua_ret, "spam") == 0) {
-                                                                       task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
+                                                                       rspamd_task_set_autolearn_class(task, "spam");
+                                                                       ret = TRUE;
+                                                               }
+                                                               else {
+                                                                       /* Multi-class: any other class name */
+                                                                       rspamd_task_set_autolearn_class(task, lua_ret);
                                                                        ret = TRUE;
                                                                }
                                                        }
@@ -1139,79 +1505,138 @@ rspamd_stat_check_autolearn(struct rspamd_task *task)
                                }
                        }
                        else if (ucl_object_type(obj) == UCL_OBJECT) {
-                               /* Try to find autolearn callback */
-                               if (cl->autolearn_cbref == 0) {
-                                       /* We don't have preprocessed cb id, so try to get it */
-                                       if (!rspamd_lua_require_function(L, "lua_bayes_learn",
-                                                                                                        "autolearn")) {
-                                               msg_err_task("cannot get autolearn library from "
-                                                                        "`lua_bayes_learn`");
-                                       }
-                                       else {
-                                               cl->autolearn_cbref = luaL_ref(L, LUA_REGISTRYINDEX);
+                               /* Check if this is a multi-class autolearn configuration */
+                               const ucl_object_t *multiclass_obj = ucl_object_lookup(obj, "multiclass");
+
+                               if (multiclass_obj && ucl_object_type(multiclass_obj) == UCL_OBJECT) {
+                                       /* Multi-class threshold-based autolearn */
+                                       const ucl_object_t *thresholds_obj = ucl_object_lookup(multiclass_obj, "thresholds");
+
+                                       if (thresholds_obj && ucl_object_type(thresholds_obj) == UCL_OBJECT) {
+                                               /* Iterate through class thresholds */
+                                               ucl_object_iter_t it = NULL;
+                                               const ucl_object_t *class_obj;
+                                               const char *class_name;
+
+                                               while ((class_obj = ucl_object_iterate(thresholds_obj, &it, true))) {
+                                                       class_name = ucl_object_key(class_obj);
+
+                                                       if (class_name && ucl_object_type(class_obj) == UCL_ARRAY && class_obj->len == 2) {
+                                                               /* [min_score, max_score] for this class */
+                                                               const ucl_object_t *min_elt = ucl_array_find_index(class_obj, 0);
+                                                               const ucl_object_t *max_elt = ucl_array_find_index(class_obj, 1);
+
+                                                               if ((ucl_object_type(min_elt) == UCL_FLOAT || ucl_object_type(min_elt) == UCL_INT) &&
+                                                                       (ucl_object_type(max_elt) == UCL_FLOAT || ucl_object_type(max_elt) == UCL_INT)) {
+
+                                                                       double min_score = ucl_object_todouble(min_elt);
+                                                                       double max_score = ucl_object_todouble(max_elt);
+
+                                                                       if (mres && mres->score >= min_score && mres->score <= max_score) {
+                                                                               rspamd_task_set_autolearn_class(task, class_name);
+                                                                               ret = TRUE;
+                                                                               msg_debug_bayes("multiclass autolearn: score %.2f matches class '%s' [%.2f, %.2f]",
+                                                                                                               mres->score, class_name, min_score, max_score);
+                                                                               break; /* Stop at first matching class */
+                                                                       }
+                                                               }
+                                                       }
+                                               }
                                        }
                                }
-
-                               if (cl->autolearn_cbref != -1) {
-                                       lua_pushcfunction(L, &rspamd_lua_traceback);
-                                       err_idx = lua_gettop(L);
-                                       lua_rawgeti(L, LUA_REGISTRYINDEX, cl->autolearn_cbref);
-
-                                       ptask = lua_newuserdata(L, sizeof(struct rspamd_task *));
-                                       *ptask = task;
-                                       rspamd_lua_setclass(L, rspamd_task_classname, -1);
-                                       /* Push the whole object as well */
-                                       ucl_object_push_lua(L, obj, true);
-
-                                       if (lua_pcall(L, 2, 1, err_idx) != 0) {
-                                               msg_err_task("call to autolearn script failed: "
-                                                                        "%s",
-                                                                        lua_tostring(L, -1));
+                               else {
+                                       /* Try to find autolearn callback */
+                                       if (cl->autolearn_cbref == 0) {
+                                               /* We don't have preprocessed cb id, so try to get it */
+                                               if (!rspamd_lua_require_function(L, "lua_bayes_learn",
+                                                                                                                "autolearn")) {
+                                                       msg_err_task("cannot get autolearn library from "
+                                                                                "`lua_bayes_learn`");
+                                               }
+                                               else {
+                                                       cl->autolearn_cbref = luaL_ref(L, LUA_REGISTRYINDEX);
+                                               }
                                        }
-                                       else {
-                                               lua_ret = lua_tostring(L, -1);
 
-                                               if (lua_ret) {
-                                                       if (strcmp(lua_ret, "ham") == 0) {
-                                                               task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
-                                                               ret = TRUE;
-                                                       }
-                                                       else if (strcmp(lua_ret, "spam") == 0) {
-                                                               task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
-                                                               ret = TRUE;
+                                       if (cl->autolearn_cbref != -1) {
+                                               lua_pushcfunction(L, &rspamd_lua_traceback);
+                                               err_idx = lua_gettop(L);
+                                               lua_rawgeti(L, LUA_REGISTRYINDEX, cl->autolearn_cbref);
+
+                                               ptask = lua_newuserdata(L, sizeof(struct rspamd_task *));
+                                               *ptask = task;
+                                               rspamd_lua_setclass(L, rspamd_task_classname, -1);
+                                               /* Push the whole object as well */
+                                               ucl_object_push_lua(L, obj, true);
+
+                                               if (lua_pcall(L, 2, 1, err_idx) != 0) {
+                                                       msg_err_task("call to autolearn script failed: "
+                                                                                "%s",
+                                                                                lua_tostring(L, -1));
+                                               }
+                                               else {
+                                                       lua_ret = lua_tostring(L, -1);
+
+                                                       if (lua_ret) {
+                                                               if (strcmp(lua_ret, "ham") == 0) {
+                                                                       rspamd_task_set_autolearn_class(task, "ham");
+                                                                       ret = TRUE;
+                                                               }
+                                                               else if (strcmp(lua_ret, "spam") == 0) {
+                                                                       rspamd_task_set_autolearn_class(task, "spam");
+                                                                       ret = TRUE;
+                                                               }
+                                                               else {
+                                                                       /* Multi-class: any other class name */
+                                                                       rspamd_task_set_autolearn_class(task, lua_ret);
+                                                                       ret = TRUE;
+                                                               }
                                                        }
                                                }
-                                       }
 
-                                       lua_settop(L, err_idx - 1);
+                                               lua_settop(L, err_idx - 1);
+                                       }
                                }
-                       }
 
-                       if (ret) {
-                               /* Do not autolearn if we have this symbol already */
-                               if (rspamd_stat_has_classifier_symbols(task, mres, cl)) {
-                                       ret = FALSE;
-                                       task->flags &= ~(RSPAMD_TASK_FLAG_LEARN_HAM |
-                                                                        RSPAMD_TASK_FLAG_LEARN_SPAM);
-                               }
-                               else if (mres != NULL) {
-                                       if (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM) {
-                                               msg_info_task("<%s>: autolearn ham for classifier "
-                                                                         "'%s' as message's "
-                                                                         "score is negative: %.2f",
-                                                                         MESSAGE_FIELD(task, message_id), cl->cfg->name,
-                                                                         mres->score);
-                                       }
-                                       else {
-                                               msg_info_task("<%s>: autolearn spam for classifier "
-                                                                         "'%s' as message's "
-                                                                         "action is reject, score: %.2f",
-                                                                         MESSAGE_FIELD(task, message_id), cl->cfg->name,
-                                                                         mres->score);
+                               if (ret) {
+                                       /* Do not autolearn if we have this symbol already */
+                                       if (rspamd_stat_has_classifier_symbols(task, mres, cl)) {
+                                               ret = FALSE;
+                                               task->flags &= ~(RSPAMD_TASK_FLAG_LEARN_HAM |
+                                                                                RSPAMD_TASK_FLAG_LEARN_SPAM |
+                                                                                RSPAMD_TASK_FLAG_LEARN_CLASS);
+                                               /* Clear the autolearn class from mempool */
+                                               rspamd_mempool_set_variable(task->task_pool, "autolearn_class", NULL, NULL);
                                        }
+                                       else if (mres != NULL) {
+                                               const char *autolearn_class = rspamd_task_get_autolearn_class(task);
+
+                                               if (autolearn_class) {
+                                                       if (strcmp(autolearn_class, "ham") == 0) {
+                                                               msg_info_task("<%s>: autolearn ham for classifier "
+                                                                                         "'%s' as message's "
+                                                                                         "score is negative: %.2f",
+                                                                                         MESSAGE_FIELD(task, message_id), cl->cfg->name,
+                                                                                         mres->score);
+                                                       }
+                                                       else if (strcmp(autolearn_class, "spam") == 0) {
+                                                               msg_info_task("<%s>: autolearn spam for classifier "
+                                                                                         "'%s' as message's "
+                                                                                         "action is reject, score: %.2f",
+                                                                                         MESSAGE_FIELD(task, message_id), cl->cfg->name,
+                                                                                         mres->score);
+                                                       }
+                                                       else {
+                                                               msg_info_task("<%s>: autolearn class '%s' for classifier "
+                                                                                         "'%s', score: %.2f",
+                                                                                         MESSAGE_FIELD(task, message_id), autolearn_class,
+                                                                                         cl->cfg->name, mres->score);
+                                                       }
+                                               }
 
-                                       task->classifier = cl->cfg->name;
-                                       break;
+                                               task->classifier = cl->cfg->name;
+                                               break;
+                                       }
                                }
                        }
                }