From: Vsevolod Stakhov Date: Sun, 20 Jul 2025 16:11:52 +0000 (+0100) Subject: [Project] Multi-class classification project baseline X-Git-Tag: 3.13.0~38^2~23 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=5e0ac6790d5a8dc92ae1ee6bff7372063ed29b03;p=thirdparty%2Frspamd.git [Project] Multi-class classification project baseline --- diff --git a/lualib/lua_bayes_redis.lua b/lualib/lua_bayes_redis.lua index 782e6fc472..59952131ab 100644 --- a/lualib/lua_bayes_redis.lua +++ b/lualib/lua_bayes_redis.lua @@ -25,25 +25,56 @@ local ucl = require "ucl" local N = "bayes" local function gen_classify_functor(redis_params, classify_script_id) - return function(task, expanded_key, id, is_spam, stat_tokens, callback) - + return function(task, expanded_key, id, class_labels, stat_tokens, callback) local function classify_redis_cb(err, data) lua_util.debugm(N, task, 'classify redis cb: %s, %s', err, data) if err then callback(task, false, err) else - callback(task, true, data[1], data[2], data[3], data[4]) + -- Handle both binary and multi-class results + if type(data[1]) == "table" then + -- Multi-class format: [learned_counts_table, outputs_table] + -- Convert to binary format for backward compatibility if needed + local learned_counts = data[1] + local outputs = data[2] + + -- For now, return ham/spam data if available for backward compatibility + local learned_ham = learned_counts["H"] or learned_counts["ham"] or 0 + local learned_spam = learned_counts["S"] or learned_counts["spam"] or 0 + local output_ham = outputs["H"] or outputs["ham"] or {} + local output_spam = outputs["S"] or outputs["spam"] or {} + + callback(task, true, learned_ham, learned_spam, output_ham, output_spam) + else + -- Binary format: [learned_ham, learned_spam, output_ham, output_spam] + callback(task, true, data[1], data[2], data[3], data[4]) + end + end + end + + -- Determine class labels to send to Redis script + local script_class_labels + if type(class_labels) == "table" then + script_class_labels = class_labels + else + -- Single class label or boolean compatibility + if class_labels == true or class_labels == "true" then + script_class_labels = "S" -- spam + elseif class_labels == false or class_labels == "false" then + script_class_labels = "H" -- ham + else + script_class_labels = class_labels -- string class label end end lua_redis.exec_redis_script(classify_script_id, - { task = task, is_write = false, key = expanded_key }, - classify_redis_cb, { expanded_key, stat_tokens }) + { task = task, is_write = false, key = expanded_key }, + classify_redis_cb, { expanded_key, script_class_labels, stat_tokens }) end end local function gen_learn_functor(redis_params, learn_script_id) - return function(task, expanded_key, id, is_spam, symbol, is_unlearn, stat_tokens, callback, maybe_text_tokens) + return function(task, expanded_key, id, class_label, symbol, is_unlearn, stat_tokens, callback, maybe_text_tokens) local function learn_redis_cb(err, data) lua_util.debugm(N, task, 'learn redis cb: %s, %s', err, data) if err then @@ -53,17 +84,24 @@ local function gen_learn_functor(redis_params, learn_script_id) end end + -- Convert class_label for backward compatibility + local script_class_label = class_label + if class_label == true or class_label == "true" then + script_class_label = "S" -- spam + elseif class_label == false or class_label == "false" then + script_class_label = "H" -- ham + end + if maybe_text_tokens then lua_redis.exec_redis_script(learn_script_id, - { task = task, is_write = true, key = expanded_key }, - learn_redis_cb, - { expanded_key, tostring(is_spam), symbol, tostring(is_unlearn), stat_tokens, maybe_text_tokens }) + { task = task, is_write = true, key = expanded_key }, + learn_redis_cb, + { expanded_key, script_class_label, symbol, tostring(is_unlearn), stat_tokens, maybe_text_tokens }) else lua_redis.exec_redis_script(learn_script_id, - { task = task, is_write = true, key = expanded_key }, - learn_redis_cb, { expanded_key, tostring(is_spam), symbol, tostring(is_unlearn), stat_tokens }) + { task = task, is_write = true, key = expanded_key }, + learn_redis_cb, { expanded_key, script_class_label, symbol, tostring(is_unlearn), stat_tokens }) end - end end @@ -112,8 +150,7 @@ end --- @param classifier_ucl ucl of the classifier config --- @param statfile_ucl ucl of the statfile config --- @return a pair of (classify_functor, learn_functor) or `nil` in case of error -exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol, is_spam, ev_base, stat_periodic_cb) - +exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol, class_label, ev_base, stat_periodic_cb) local redis_params = load_redis_params(classifier_ucl, statfile_ucl) if not redis_params then @@ -137,7 +174,6 @@ exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol, if ev_base then rspamd_config:add_periodic(ev_base, 0.0, function(cfg, _) - local function stat_redis_cb(err, data) lua_util.debugm(N, cfg, 'stat redis cb: %s, %s', err, data) @@ -162,12 +198,23 @@ exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol, end end + -- Convert class_label to learn key + local learn_key + if class_label == true or class_label == "true" or class_label == "S" then + learn_key = "learns_spam" + elseif class_label == false or class_label == "false" or class_label == "H" then + learn_key = "learns_ham" + else + -- For other class labels, use learns_ + learn_key = "learns_" .. string.lower(tostring(class_label)) + end + lua_redis.exec_redis_script(stat_script_id, - { ev_base = ev_base, cfg = cfg, is_write = false }, - stat_redis_cb, { tostring(cursor), - symbol, - is_spam and "learns_spam" or "learns_ham", - tostring(max_users) }) + { ev_base = ev_base, cfg = cfg, is_write = false }, + stat_redis_cb, { tostring(cursor), + symbol, + learn_key, + tostring(max_users) }) return statfile_ucl.monitor_timeout or classifier_ucl.monitor_timeout or 30.0 end) end @@ -178,7 +225,6 @@ end local function gen_cache_check_functor(redis_params, check_script_id, conf) local packed_conf = ucl.to_format(conf, 'msgpack') return function(task, cache_id, callback) - local function classify_redis_cb(err, data) lua_util.debugm(N, task, 'check cache redis cb: %s, %s (%s)', err, data, type(data)) if err then @@ -194,24 +240,33 @@ local function gen_cache_check_functor(redis_params, check_script_id, conf) lua_util.debugm(N, task, 'checking cache: %s', cache_id) lua_redis.exec_redis_script(check_script_id, - { task = task, is_write = false, key = cache_id }, - classify_redis_cb, { cache_id, packed_conf }) + { task = task, is_write = false, key = cache_id }, + classify_redis_cb, { cache_id, packed_conf }) end end local function gen_cache_learn_functor(redis_params, learn_script_id, conf) local packed_conf = ucl.to_format(conf, 'msgpack') - return function(task, cache_id, is_spam) + return function(task, cache_id, class_name) local function learn_redis_cb(err, data) lua_util.debugm(N, task, 'learn_cache redis cb: %s, %s', err, data) end - lua_util.debugm(N, task, 'try to learn cache: %s', cache_id) - lua_redis.exec_redis_script(learn_script_id, - { task = task, is_write = true, key = cache_id }, - learn_redis_cb, - { cache_id, is_spam and "1" or "0", packed_conf }) + -- Handle backward compatibility for boolean values + local cache_class_name = class_name + if type(class_name) == "boolean" then + cache_class_name = class_name and "spam" or "ham" + elseif class_name == true or class_name == "true" then + cache_class_name = "spam" + elseif class_name == false or class_name == "false" then + cache_class_name = "ham" + end + lua_util.debugm(N, task, 'try to learn cache: %s as %s', cache_id, cache_class_name) + lua_redis.exec_redis_script(learn_script_id, + { task = task, is_write = true, key = cache_id }, + learn_redis_cb, + { cache_id, cache_class_name, packed_conf }) end end @@ -225,8 +280,8 @@ exports.lua_bayes_init_cache = function(classifier_ucl, statfile_ucl) local default_conf = { cache_prefix = "learned_ids", cache_max_elt = 10000, -- Maximum number of elements in the cache key - cache_max_keys = 5, -- Maximum number of keys in the cache - cache_elt_len = 32, -- Length of the element in the cache (will trim id to that value) + cache_max_keys = 5, -- Maximum number of keys in the cache + cache_elt_len = 32, -- Length of the element in the cache (will trim id to that value) } local conf = lua_util.override_defaults(default_conf, classifier_ucl) @@ -241,7 +296,7 @@ exports.lua_bayes_init_cache = function(classifier_ucl, statfile_ucl) local learn_script_id = lua_redis.load_redis_script_from_file("bayes_cache_learn.lua", redis_params) return gen_cache_check_functor(redis_params, check_script_id, conf), gen_cache_learn_functor(redis_params, - learn_script_id, conf) + learn_script_id, conf) end return exports diff --git a/lualib/redis_scripts/bayes_cache_learn.lua b/lualib/redis_scripts/bayes_cache_learn.lua index 7d44a73efc..b3e15a9bc3 100644 --- a/lualib/redis_scripts/bayes_cache_learn.lua +++ b/lualib/redis_scripts/bayes_cache_learn.lua @@ -1,12 +1,19 @@ --- Lua script to perform cache checking for bayes classification +-- Lua script to perform cache checking for bayes classification (multi-class) -- This script accepts the following parameters: -- key1 - cache id --- key2 - is spam (1 or 0) +-- key2 - class name (e.g. "spam", "ham", "transactional") -- key3 - configuration table in message pack local cache_id = KEYS[1] -local is_spam = KEYS[2] +local class_name = KEYS[2] local conf = cmsgpack.unpack(KEYS[3]) + +-- Handle backward compatibility for binary values +if class_name == "1" then + class_name = "spam" +elseif class_name == "0" then + class_name = "ham" +end cache_id = string.sub(cache_id, 1, conf.cache_elt_len) -- Try each prefix that is in Redis (as some other instance might have set it) @@ -15,8 +22,8 @@ for i = 0, conf.cache_max_keys do local have = redis.call('HGET', prefix, cache_id) if have then - -- Already in cache, but is_spam changes when relearning - redis.call('HSET', prefix, cache_id, is_spam) + -- Already in cache, but class_name changes when relearning + redis.call('HSET', prefix, cache_id, class_name) return false end end @@ -30,7 +37,7 @@ for i = 0, conf.cache_max_keys do if count < lim then -- We can add it to this prefix - redis.call('HSET', prefix, cache_id, is_spam) + redis.call('HSET', prefix, cache_id, class_name) added = true end end @@ -46,7 +53,7 @@ if not added then if exists then if not expired then redis.call('DEL', prefix) - redis.call('HSET', prefix, cache_id, is_spam) + redis.call('HSET', prefix, cache_id, class_name) -- Do not expire anything else expired = true diff --git a/lualib/redis_scripts/bayes_classify.lua b/lualib/redis_scripts/bayes_classify.lua index e94f645fdf..8e6feb32f8 100644 --- a/lualib/redis_scripts/bayes_classify.lua +++ b/lualib/redis_scripts/bayes_classify.lua @@ -1,37 +1,90 @@ --- Lua script to perform bayes classification +-- Lua script to perform bayes classification (multi-class) -- This script accepts the following parameters: -- key1 - prefix for bayes tokens (e.g. for per-user classification) --- key2 - set of tokens encoded in messagepack array of strings +-- key2 - class labels: either table of all class labels (multi-class) or single string (binary) +-- key3 - set of tokens encoded in messagepack array of strings local prefix = KEYS[1] -local output_spam = {} -local output_ham = {} +local class_labels_arg = KEYS[2] +local input_tokens = cmsgpack.unpack(KEYS[3]) -local learned_ham = tonumber(redis.call('HGET', prefix, 'learns_ham')) or 0 -local learned_spam = tonumber(redis.call('HGET', prefix, 'learns_spam')) or 0 +-- Determine if this is multi-class (table) or binary (string) +local class_labels = {} +if type(class_labels_arg) == "table" then + class_labels = class_labels_arg +else + -- Binary compatibility: handle old boolean or single string format + if class_labels_arg == "true" then + class_labels = { "S" } -- spam + elseif class_labels_arg == "false" then + class_labels = { "H" } -- ham + else + class_labels = { class_labels_arg } -- single class label + end +end --- Output is a set of pairs (token_index, token_count), tokens that are not --- found are not filled. --- This optimisation will save a lot of space for sparse tokens, and in Bayes that assumption is normally held +-- Get learned counts for all classes +local learned_counts = {} +for _, label in ipairs(class_labels) do + local key = 'learns_' .. string.lower(label) + -- Also try legacy keys for backward compatibility + if label == 'H' then + key = 'learns_ham' + elseif label == 'S' then + key = 'learns_spam' + end + learned_counts[label] = tonumber(redis.call('HGET', prefix, key)) or 0 +end -if learned_ham > 0 and learned_spam > 0 then - local input_tokens = cmsgpack.unpack(KEYS[2]) - for i, token in ipairs(input_tokens) do - local token_data = redis.call('HMGET', token, 'H', 'S') +-- Get token data for all classes (only if we have learns for any class) +local outputs = {} +local has_learns = false +for _, count in pairs(learned_counts) do + if count > 0 then + has_learns = true + break + end +end - if token_data then - local ham_count = token_data[1] - local spam_count = token_data[2] +if has_learns then + -- Initialize outputs for each class + for _, label in ipairs(class_labels) do + outputs[label] = {} + end - if ham_count then - table.insert(output_ham, { i, tonumber(ham_count) }) - end + -- Process each token + for i, token in ipairs(input_tokens) do + local token_data = redis.call('HMGET', token, unpack(class_labels)) - if spam_count then - table.insert(output_spam, { i, tonumber(spam_count) }) + if token_data then + for j, label in ipairs(class_labels) do + local count = token_data[j] + if count then + table.insert(outputs[label], { i, tonumber(count) }) + end end end end end -return { learned_ham, learned_spam, output_ham, output_spam } \ No newline at end of file +-- Format output for backward compatibility +if #class_labels == 2 and class_labels[1] == 'H' and class_labels[2] == 'S' then + -- Binary format: [learned_ham, learned_spam, output_ham, output_spam] + return { + learned_counts['H'] or 0, + learned_counts['S'] or 0, + outputs['H'] or {}, + outputs['S'] or {} + } +elseif #class_labels == 2 and class_labels[1] == 'S' and class_labels[2] == 'H' then + -- Binary format: [learned_ham, learned_spam, output_ham, output_spam] + return { + learned_counts['H'] or 0, + learned_counts['S'] or 0, + outputs['H'] or {}, + outputs['S'] or {} + } +else + -- Multi-class format: [learned_counts_table, outputs_table] + return { learned_counts, outputs } +end diff --git a/lualib/redis_scripts/bayes_learn.lua b/lualib/redis_scripts/bayes_learn.lua index 5456165b69..b284a28123 100644 --- a/lualib/redis_scripts/bayes_learn.lua +++ b/lualib/redis_scripts/bayes_learn.lua @@ -1,14 +1,14 @@ --- Lua script to perform bayes learning +-- Lua script to perform bayes learning (multi-class) -- This script accepts the following parameters: -- key1 - prefix for bayes tokens (e.g. for per-user classification) --- key2 - boolean is_spam +-- key2 - class label string (e.g. "S", "H", "T") -- key3 - string symbol -- key4 - boolean is_unlearn -- key5 - set of tokens encoded in messagepack array of strings -- key6 - set of text tokens (if any) encoded in messagepack array of strings (size must be twice of `KEYS[5]`) local prefix = KEYS[1] -local is_spam = KEYS[2] == 'true' and true or false +local class_label = KEYS[2] local symbol = KEYS[3] local is_unlearn = KEYS[4] == 'true' and true or false local input_tokens = cmsgpack.unpack(KEYS[5]) @@ -18,11 +18,25 @@ if KEYS[6] then text_tokens = cmsgpack.unpack(KEYS[6]) end -local hash_key = is_spam and 'S' or 'H' -local learned_key = is_spam and 'learns_spam' or 'learns_ham' +-- Handle backward compatibility for boolean values +if class_label == 'true' then + class_label = 'S' -- spam +elseif class_label == 'false' then + class_label = 'H' -- ham +end + +local hash_key = class_label +local learned_key = 'learns_' .. string.lower(class_label) + +-- Handle legacy keys for backward compatibility +if class_label == 'S' then + learned_key = 'learns_spam' +elseif class_label == 'H' then + learned_key = 'learns_ham' +end redis.call('SADD', symbol .. '_keys', prefix) -redis.call('HSET', prefix, 'version', '2') -- new schema +redis.call('HSET', prefix, 'version', '2') -- new schema redis.call('HINCRBY', prefix, learned_key, is_unlearn and -1 or 1) -- increase or decrease learned count for i, token in ipairs(input_tokens) do diff --git a/src/libserver/cfg_file.h b/src/libserver/cfg_file.h index 36941da7ad..cd2ab43141 100644 --- a/src/libserver/cfg_file.h +++ b/src/libserver/cfg_file.h @@ -139,7 +139,8 @@ struct rspamd_statfile_config { char *symbol; /**< symbol of statfile */ char *label; /**< label of this statfile */ ucl_object_t *opts; /**< other options */ - gboolean is_spam; /**< spam flag */ + char *class_name; /**< class name for multi-class classification */ + gboolean is_spam; /**< DEPRECATED: spam flag - use class_name instead */ struct rspamd_classifier_config *clcf; /**< parent pointer of classifier configuration */ gpointer data; /**< opaque data */ }; @@ -182,6 +183,8 @@ struct rspamd_classifier_config { double min_prob_strength; /**< use only tokens with probability in [0.5 - MPS, 0.5 + MPS] */ unsigned int min_learns; /**< minimum number of learns for each statfile */ unsigned int flags; + GHashTable *class_labels; /**< class_name -> backend_symbol mapping for multi-class */ + GPtrArray *class_names; /**< ordered list of class names */ }; struct rspamd_worker_bind_conf { @@ -621,12 +624,25 @@ void rspamd_config_insert_classify_symbols(struct rspamd_config *cfg); */ gboolean rspamd_config_check_statfiles(struct rspamd_classifier_config *cf); -/* - * Find classifier config by name +/** + * Multi-class configuration helpers + */ +gboolean rspamd_config_parse_class_labels(ucl_object_t *obj, + GHashTable **class_labels); + +gboolean rspamd_config_migrate_binary_config(struct rspamd_statfile_config *stcf); + +gboolean rspamd_config_validate_class_config(struct rspamd_classifier_config *ccf, + GError **err); + +const char *rspamd_config_get_class_label(struct rspamd_classifier_config *ccf, + const char *class_name); + +/** + * Find classifier by name */ struct rspamd_classifier_config *rspamd_config_find_classifier( - struct rspamd_config *cfg, - const char *name); + struct rspamd_config *cfg, const char *name); void rspamd_ucl_add_conf_macros(struct ucl_parser *parser, struct rspamd_config *cfg); diff --git a/src/libserver/cfg_rcl.cxx b/src/libserver/cfg_rcl.cxx index 0a48e8a4f4..3f0a9606a2 100644 --- a/src/libserver/cfg_rcl.cxx +++ b/src/libserver/cfg_rcl.cxx @@ -1197,37 +1197,84 @@ rspamd_rcl_statfile_handler(rspamd_mempool_t *pool, const ucl_object_t *obj, st->opts = (ucl_object_t *) obj; st->clcf = ccf; - const auto *val = ucl_object_lookup(obj, "spam"); - if (val == nullptr) { + /* Handle migration from old 'spam' field to new 'class' field */ + const auto *class_val = ucl_object_lookup(obj, "class"); + const auto *spam_val = ucl_object_lookup(obj, "spam"); + + if (class_val != nullptr && spam_val != nullptr) { + msg_warn_config("statfile %s has both 'class' and 'spam' fields, using 'class' field", + st->symbol); + } + + if (class_val == nullptr && spam_val == nullptr) { + /* Neither field present, try to guess by symbol name */ msg_info_config( - "statfile %s has no explicit 'spam' setting, trying to guess by symbol", + "statfile %s has no explicit 'class' or 'spam' setting, trying to guess by symbol", st->symbol); if (rspamd_substring_search_caseless(st->symbol, strlen(st->symbol), "spam", 4) != -1) { st->is_spam = TRUE; + st->class_name = rspamd_mempool_strdup(pool, "spam"); } else if (rspamd_substring_search_caseless(st->symbol, strlen(st->symbol), "ham", 3) != -1) { st->is_spam = FALSE; + st->class_name = rspamd_mempool_strdup(pool, "ham"); } else { g_set_error(err, CFG_RCL_ERROR, EINVAL, - "cannot guess spam setting from %s", + "cannot guess class setting from %s, please specify 'class' field", st->symbol); return FALSE; } - msg_info_config("guessed that statfile with symbol %s is %s", - st->symbol, - st->is_spam ? "spam" : "ham"); + msg_info_config("guessed that statfile with symbol %s has class '%s'", + st->symbol, st->class_name); } + else if (class_val == nullptr && spam_val != nullptr) { + /* Only spam field present - migrate to class */ + msg_warn_config("statfile %s uses deprecated 'spam' field, please use 'class' instead", + st->symbol); + if (st->is_spam) { + st->class_name = rspamd_mempool_strdup(pool, "spam"); + } + else { + st->class_name = rspamd_mempool_strdup(pool, "ham"); + } + } + /* If class field is present, it was already parsed by the default parser */ return TRUE; } return FALSE; } +static gboolean +rspamd_rcl_class_labels_handler(rspamd_mempool_t *pool, + const ucl_object_t *obj, + const char *key, + gpointer ud, + struct rspamd_rcl_section *section, + GError **err) +{ + auto *ccf = static_cast(ud); + + if (obj->type != UCL_OBJECT) { + g_set_error(err, CFG_RCL_ERROR, EINVAL, + "class_labels must be an object"); + return FALSE; + } + + if (!rspamd_config_parse_class_labels((ucl_object_t *) obj, &ccf->class_labels)) { + g_set_error(err, CFG_RCL_ERROR, EINVAL, + "invalid class_labels configuration"); + return FALSE; + } + + return TRUE; +} + static gboolean rspamd_rcl_classifier_handler(rspamd_mempool_t *pool, const ucl_object_t *obj, @@ -1375,6 +1422,21 @@ rspamd_rcl_classifier_handler(rspamd_mempool_t *pool, } ccf->opts = (ucl_object_t *) obj; + + /* Validate multi-class configuration */ + GError *validation_err = nullptr; + if (!rspamd_config_validate_class_config(ccf, &validation_err)) { + if (validation_err) { + g_propagate_error(err, validation_err); + } + else { + g_set_error(err, CFG_RCL_ERROR, EINVAL, + "multi-class configuration validation failed for classifier '%s'", + ccf->name ? ccf->name : "unknown"); + } + return FALSE; + } + cfg->classifiers = g_list_prepend(cfg->classifiers, ccf); return TRUE; @@ -2504,6 +2566,18 @@ rspamd_rcl_config_init(struct rspamd_config *cfg, GHashTable *skip_sections) 0, "Name of classifier"); + /* + * Multi-class configuration + */ + rspamd_rcl_add_section_doc(&top, sub, + "class_labels", nullptr, + rspamd_rcl_class_labels_handler, + UCL_OBJECT, + FALSE, + TRUE, + sub->doc_ref, + "Class to backend label mapping for multi-class classification"); + /* * Statfile defaults */ @@ -2521,12 +2595,18 @@ rspamd_rcl_config_init(struct rspamd_config *cfg, GHashTable *skip_sections) G_STRUCT_OFFSET(struct rspamd_statfile_config, label), 0, "Statfile unique label"); + rspamd_rcl_add_default_handler(ssub, + "class", + rspamd_rcl_parse_struct_string, + G_STRUCT_OFFSET(struct rspamd_statfile_config, class_name), + 0, + "Class name for multi-class classification"); rspamd_rcl_add_default_handler(ssub, "spam", rspamd_rcl_parse_struct_boolean, G_STRUCT_OFFSET(struct rspamd_statfile_config, is_spam), 0, - "Sets if this statfile contains spam samples"); + "DEPRECATED: Sets if this statfile contains spam samples (use 'class' instead)"); } if (!(skip_sections && g_hash_table_lookup(skip_sections, "composite"))) { diff --git a/src/libserver/cfg_utils.cxx b/src/libserver/cfg_utils.cxx index c7bb202108..c8c0834397 100644 --- a/src/libserver/cfg_utils.cxx +++ b/src/libserver/cfg_utils.cxx @@ -3042,3 +3042,169 @@ rspamd_ip_is_local_cfg(struct rspamd_config *cfg, return FALSE; } + +gboolean +rspamd_config_parse_class_labels(ucl_object_t *obj, GHashTable **class_labels) +{ + const ucl_object_t *cur; + ucl_object_iter_t it = nullptr; + const char *class_name, *label; + + if (!obj || ucl_object_type(obj) != UCL_OBJECT) { + return FALSE; + } + + *class_labels = g_hash_table_new_full(g_str_hash, g_str_equal, g_free, g_free); + + while ((cur = ucl_object_iterate(obj, &it, true)) != nullptr) { + class_name = ucl_object_key(cur); + label = ucl_object_tostring(cur); + + if (class_name && label) { + /* Validate class name: alphanumeric + underscore, max 32 chars */ + if (strlen(class_name) > 32) { + msg_err("class name '%s' is too long (max 32 characters)", class_name); + g_hash_table_destroy(*class_labels); + *class_labels = nullptr; + return FALSE; + } + + for (const char *p = class_name; *p; p++) { + if (!g_ascii_isalnum(*p) && *p != '_') { + msg_err("class name '%s' contains invalid character '%c'", class_name, *p); + g_hash_table_destroy(*class_labels); + *class_labels = nullptr; + return FALSE; + } + } + + /* Validate label uniqueness */ + GHashTableIter label_iter; + gpointer key, value; + g_hash_table_iter_init(&label_iter, *class_labels); + while (g_hash_table_iter_next(&label_iter, &key, &value)) { + if (strcmp((const char *) value, label) == 0) { + msg_err("backend label '%s' is used by multiple classes", label); + g_hash_table_destroy(*class_labels); + *class_labels = nullptr; + return FALSE; + } + } + + g_hash_table_insert(*class_labels, g_strdup(class_name), g_strdup(label)); + } + } + + return g_hash_table_size(*class_labels) > 0; +} + +gboolean +rspamd_config_migrate_binary_config(struct rspamd_statfile_config *stcf) +{ + if (stcf->class_name != nullptr) { + /* Already migrated or using new format */ + return TRUE; + } + + if (stcf->is_spam) { + stcf->class_name = g_strdup("spam"); + msg_info("migrated statfile '%s' from is_spam=true to class='spam'", + stcf->symbol ? stcf->symbol : "unknown"); + } + else { + stcf->class_name = g_strdup("ham"); + msg_info("migrated statfile '%s' from is_spam=false to class='ham'", + stcf->symbol ? stcf->symbol : "unknown"); + } + + return TRUE; +} + +gboolean +rspamd_config_validate_class_config(struct rspamd_classifier_config *ccf, GError **err) +{ + GList *cur; + GHashTable *seen_classes = nullptr; + struct rspamd_statfile_config *stcf; + unsigned int class_count = 0; + + if (!ccf || !ccf->statfiles) { + g_set_error(err, g_quark_from_static_string("config"), 1, + "classifier has no statfiles defined"); + return FALSE; + } + + seen_classes = g_hash_table_new_full(g_str_hash, g_str_equal, g_free, nullptr); + + /* Iterate through statfiles and collect classes */ + cur = ccf->statfiles; + while (cur) { + stcf = (struct rspamd_statfile_config *) cur->data; + + /* Migrate binary config if needed */ + if (!rspamd_config_migrate_binary_config(stcf)) { + g_set_error(err, g_quark_from_static_string("config"), 1, + "failed to migrate binary config for statfile '%s'", + stcf->symbol ? stcf->symbol : "unknown"); + g_hash_table_destroy(seen_classes); + return FALSE; + } + + /* Check class name */ + if (!stcf->class_name || strlen(stcf->class_name) == 0) { + g_set_error(err, g_quark_from_static_string("config"), 1, + "statfile '%s' has no class defined", + stcf->symbol ? stcf->symbol : "unknown"); + g_hash_table_destroy(seen_classes); + return FALSE; + } + + /* Track unique classes */ + if (!g_hash_table_contains(seen_classes, stcf->class_name)) { + g_hash_table_insert(seen_classes, g_strdup(stcf->class_name), GINT_TO_POINTER(1)); + class_count++; + } + + cur = g_list_next(cur); + } + + /* Validate class count */ + if (class_count < 2) { + g_set_error(err, g_quark_from_static_string("config"), 1, + "classifier must have at least 2 classes, found %u", class_count); + g_hash_table_destroy(seen_classes); + return FALSE; + } + + if (class_count > 20) { + msg_warn("classifier has %u classes, performance may be degraded above 20 classes", + class_count); + } + + /* Initialize classifier class tracking */ + if (ccf->class_names) { + g_ptr_array_unref(ccf->class_names); + } + ccf->class_names = g_ptr_array_new_with_free_func(g_free); + + /* Populate class names array */ + GHashTableIter iter; + gpointer key, value; + g_hash_table_iter_init(&iter, seen_classes); + while (g_hash_table_iter_next(&iter, &key, &value)) { + g_ptr_array_add(ccf->class_names, g_strdup((const char *) key)); + } + + g_hash_table_destroy(seen_classes); + return TRUE; +} + +const char * +rspamd_config_get_class_label(struct rspamd_classifier_config *ccf, const char *class_name) +{ + if (!ccf || !ccf->class_labels || !class_name) { + return nullptr; + } + + return (const char *) g_hash_table_lookup(ccf->class_labels, class_name); +} diff --git a/src/libserver/task.c b/src/libserver/task.c index 9f5b1f00a1..e043582846 100644 --- a/src/libserver/task.c +++ b/src/libserver/task.c @@ -730,7 +730,7 @@ rspamd_task_process(struct rspamd_task *task, unsigned int stages) if (all_done && (task->flags & RSPAMD_TASK_FLAG_LEARN_AUTO) && !RSPAMD_TASK_IS_EMPTY(task) && - !(task->flags & (RSPAMD_TASK_FLAG_LEARN_SPAM | RSPAMD_TASK_FLAG_LEARN_HAM))) { + !(task->flags & (RSPAMD_TASK_FLAG_LEARN_SPAM | RSPAMD_TASK_FLAG_LEARN_HAM | RSPAMD_TASK_FLAG_LEARN_CLASS))) { rspamd_stat_check_autolearn(task); } break; @@ -738,12 +738,32 @@ rspamd_task_process(struct rspamd_task *task, unsigned int stages) case RSPAMD_TASK_STAGE_LEARN: case RSPAMD_TASK_STAGE_LEARN_PRE: case RSPAMD_TASK_STAGE_LEARN_POST: - if (task->flags & (RSPAMD_TASK_FLAG_LEARN_SPAM | RSPAMD_TASK_FLAG_LEARN_HAM)) { + if (task->flags & (RSPAMD_TASK_FLAG_LEARN_SPAM | RSPAMD_TASK_FLAG_LEARN_HAM | RSPAMD_TASK_FLAG_LEARN_CLASS)) { if (task->err == NULL) { - if (!rspamd_stat_learn(task, - task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM, - task->cfg->lua_state, task->classifier, - st, &stat_error)) { + gboolean learn_result = FALSE; + + if (task->flags & RSPAMD_TASK_FLAG_LEARN_CLASS) { + /* Multi-class learning */ + const char *autolearn_class = rspamd_task_get_autolearn_class(task); + if (autolearn_class) { + learn_result = rspamd_stat_learn_class(task, autolearn_class, + task->cfg->lua_state, task->classifier, + st, &stat_error); + } + else { + g_set_error(&stat_error, g_quark_from_static_string("stat"), 500, + "No autolearn class specified for multi-class learning"); + } + } + else { + /* Legacy binary learning */ + learn_result = rspamd_stat_learn(task, + task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM, + task->cfg->lua_state, task->classifier, + st, &stat_error); + } + + if (!learn_result) { if (stat_error == NULL) { g_set_error(&stat_error, diff --git a/src/libserver/task.h b/src/libserver/task.h index 1c1778fee4..a1742e1608 100644 --- a/src/libserver/task.h +++ b/src/libserver/task.h @@ -104,9 +104,9 @@ enum rspamd_task_stage { #define RSPAMD_TASK_FLAG_LEARN_SPAM (1u << 12u) #define RSPAMD_TASK_FLAG_LEARN_HAM (1u << 13u) #define RSPAMD_TASK_FLAG_LEARN_AUTO (1u << 14u) +#define RSPAMD_TASK_FLAG_LEARN_CLASS (1u << 25u) #define RSPAMD_TASK_FLAG_BROKEN_HEADERS (1u << 15u) -#define RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS (1u << 16u) -#define RSPAMD_TASK_FLAG_HAS_HAM_TOKENS (1u << 17u) +/* Removed RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS and RSPAMD_TASK_FLAG_HAS_HAM_TOKENS - not needed in multi-class */ #define RSPAMD_TASK_FLAG_EMPTY (1u << 18u) #define RSPAMD_TASK_FLAG_PROFILE (1u << 19u) #define RSPAMD_TASK_FLAG_GREYLISTED (1u << 20u) @@ -114,7 +114,7 @@ enum rspamd_task_stage { #define RSPAMD_TASK_FLAG_SSL (1u << 22u) #define RSPAMD_TASK_FLAG_BAD_UNICODE (1u << 23u) #define RSPAMD_TASK_FLAG_MESSAGE_REWRITE (1u << 24u) -#define RSPAMD_TASK_FLAG_MAX_SHIFT (24u) +#define RSPAMD_TASK_FLAG_MAX_SHIFT (25u) /* Request has been done by a local client */ #define RSPAMD_TASK_PROTOCOL_FLAG_LOCAL_CLIENT (1u << 1u) diff --git a/src/libstat/backends/cdb_backend.cxx b/src/libstat/backends/cdb_backend.cxx index 0f55a725c4..f6ca9c12d8 100644 --- a/src/libstat/backends/cdb_backend.cxx +++ b/src/libstat/backends/cdb_backend.cxx @@ -393,7 +393,6 @@ rspamd_cdb_process_tokens(struct rspamd_task *task, gpointer runtime) { auto *cdbp = CDB_FROM_RAW(runtime); - bool seen_values = false; for (auto i = 0u; i < tokens->len; i++) { rspamd_token_t *tok; @@ -403,21 +402,13 @@ rspamd_cdb_process_tokens(struct rspamd_task *task, if (res) { tok->values[id] = res.value(); - seen_values = true; } else { tok->values[id] = 0; } } - if (seen_values) { - if (cdbp->is_spam()) { - task->flags |= RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS; - } - else { - task->flags |= RSPAMD_TASK_FLAG_HAS_HAM_TOKENS; - } - } + /* No longer need to set flags - multi-class handles missing data naturally */ return true; } @@ -488,4 +479,4 @@ void rspamd_cdb_close(gpointer ctx) { auto *cdbp = CDB_FROM_RAW(ctx); delete cdbp; -} \ No newline at end of file +} diff --git a/src/libstat/backends/mmaped_file.c b/src/libstat/backends/mmaped_file.c index 4430bb9a43..a6423a1e6c 100644 --- a/src/libstat/backends/mmaped_file.c +++ b/src/libstat/backends/mmaped_file.c @@ -85,8 +85,7 @@ typedef struct { #define RSPAMD_STATFILE_VERSION \ { \ - '1', '2' \ - } + '1', '2'} #define BACKUP_SUFFIX ".old" static void rspamd_mmaped_file_set_block_common(rspamd_mempool_t *pool, @@ -958,12 +957,7 @@ rspamd_mmaped_file_process_tokens(struct rspamd_task *task, GPtrArray *tokens, tok->values[id] = rspamd_mmaped_file_get_block(mf, h1, h2); } - if (mf->cf->is_spam) { - task->flags |= RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS; - } - else { - task->flags |= RSPAMD_TASK_FLAG_HAS_HAM_TOKENS; - } + /* No longer need to set flags - multi-class handles missing data naturally */ return TRUE; } diff --git a/src/libstat/backends/redis_backend.cxx b/src/libstat/backends/redis_backend.cxx index 7137904e99..01ed818c47 100644 --- a/src/libstat/backends/redis_backend.cxx +++ b/src/libstat/backends/redis_backend.cxx @@ -121,9 +121,9 @@ public: } static auto maybe_recover_from_mempool(struct rspamd_task *task, const char *redis_object_expanded, - bool is_spam) -> std::optional *> + const char *class_label) -> std::optional *> { - auto var_name = fmt::format("{}_{}", redis_object_expanded, is_spam ? "S" : "H"); + auto var_name = fmt::format("{}_{}", redis_object_expanded, class_label); auto *res = rspamd_mempool_get_variable(task->task_pool, var_name.c_str()); if (res) { @@ -158,9 +158,9 @@ public: return true; } - auto save_in_mempool(bool is_spam) const + auto save_in_mempool(const char *class_label) const { - auto var_name = fmt::format("{}_{}", redis_object_expanded, is_spam ? "S" : "H"); + auto var_name = fmt::format("{}_{}", redis_object_expanded, class_label); /* We do not set destructor for the variable, as it should be already added on creation */ rspamd_mempool_set_variable(task->task_pool, var_name.c_str(), (gpointer) this, nullptr); msg_debug_bayes("saved runtime in mempool at %s", var_name.c_str()); @@ -177,6 +177,26 @@ rspamd_redis_stat_quark(void) return g_quark_from_static_string(M); } +/* + * Get the class label for a statfile (for multi-class support) + */ +static const char * +get_class_label(struct rspamd_statfile_config *stcf) +{ + /* Try to get the label from the classifier config first */ + if (stcf->clcf && stcf->clcf->class_labels && stcf->class_name) { + const char *label = rspamd_config_get_class_label(stcf->clcf, stcf->class_name); + if (label) { + return label; + } + /* If no label mapping found, use class name directly */ + return stcf->class_name; + } + + /* Fallback to legacy binary classification */ + return stcf->is_spam ? "S" : "H"; +} + /* * Non-static for lua unit testing */ @@ -541,7 +561,7 @@ rspamd_redis_init(struct rspamd_stat_ctx *ctx, ucl_object_push_lua(L, st->classifier->cfg->opts, false); ucl_object_push_lua(L, st->stcf->opts, false); lua_pushstring(L, backend->stcf->symbol); - lua_pushboolean(L, backend->stcf->is_spam); + lua_pushstring(L, get_class_label(backend->stcf)); /* Pass class label instead of boolean */ /* Push event loop if there is one available (e.g. we are not in rspamadm mode) */ if (ctx->event_loop) { @@ -607,10 +627,12 @@ rspamd_redis_runtime(struct rspamd_task *task, return nullptr; } + const char *class_label = get_class_label(stcf); + /* Look for the cached results */ if (!learn) { auto maybe_existing = redis_stat_runtime::maybe_recover_from_mempool(task, - object_expanded, stcf->is_spam); + object_expanded, class_label); if (maybe_existing) { auto *rt = maybe_existing.value(); @@ -624,24 +646,45 @@ rspamd_redis_runtime(struct rspamd_task *task, /* No cached result (or learn), create new one */ auto *rt = new redis_stat_runtime(ctx, task, object_expanded); - if (!learn) { + if (!learn && stcf->clcf && stcf->clcf->class_names && stcf->clcf->class_names->len > 2) { /* - * For check, we also need to create the opposite class runtime to avoid - * double call for Redis scripts. - * This runtime will be filled later. + * For multi-class classification, we need to create runtimes for ALL classes + * to avoid multiple Redis calls. The actual Redis call will fetch data for all classes. */ + GList *cur = stcf->clcf->statfiles; + while (cur) { + auto *other_stcf = (struct rspamd_statfile_config *) cur->data; + if (other_stcf != stcf) { + const char *other_label = get_class_label(other_stcf); + + auto maybe_other_rt = redis_stat_runtime::maybe_recover_from_mempool(task, + object_expanded, other_label); + if (!maybe_other_rt) { + auto *other_rt = new redis_stat_runtime(ctx, task, object_expanded); + other_rt->save_in_mempool(other_label); + other_rt->need_redis_call = false; + } + } + cur = g_list_next(cur); + } + } + else if (!learn) { + /* + * For binary classification, create the opposite class runtime to avoid + * double call for Redis scripts (backward compatibility). + */ + const char *opposite_label = stcf->is_spam ? "H" : "S"; auto maybe_opposite_rt = redis_stat_runtime::maybe_recover_from_mempool(task, - object_expanded, - !stcf->is_spam); + object_expanded, opposite_label); if (!maybe_opposite_rt) { auto *opposite_rt = new redis_stat_runtime(ctx, task, object_expanded); - opposite_rt->save_in_mempool(!stcf->is_spam); + opposite_rt->save_in_mempool(opposite_label); opposite_rt->need_redis_call = false; } } - rt->save_in_mempool(stcf->is_spam); + rt->save_in_mempool(class_label); return rt; } @@ -823,16 +866,10 @@ rspamd_redis_classified(lua_State *L) bool result = lua_toboolean(L, 2); if (result) { - /* Indexes: - * 3 - learned_ham (int) - * 4 - learned_spam (int) - * 5 - ham_tokens (pair) - * 6 - spam_tokens (pair) + /* Check if this is binary format [learned_ham, learned_spam, ham_tokens, spam_tokens] + * or multi-class format [learned_counts_table, outputs_table] */ - /* - * We need to fill our runtime AND the opposite runtime - */ auto filler_func = [](redis_stat_runtime *rt, lua_State *L, unsigned learned, int tokens_pos) { rt->learned = learned; redis_stat_runtime::result_type *res; @@ -854,32 +891,96 @@ rspamd_redis_classified(lua_State *L) rt->set_results(res); }; - auto opposite_rt_maybe = redis_stat_runtime::maybe_recover_from_mempool(task, - rt->redis_object_expanded, - !rt->stcf->is_spam); + /* Check if result[3] is a number (binary) or table (multi-class) */ + lua_rawgeti(L, 3, 1); /* Get first element of result array */ + bool is_binary_format = lua_isnumber(L, -1); + lua_pop(L, 1); - if (!opposite_rt_maybe) { - msg_err_task("internal error: cannot find opposite runtime for cookie %s", cookie); + if (is_binary_format) { + /* Binary format: [learned_ham, learned_spam, ham_tokens, spam_tokens] */ - return 0; - } + /* Find the opposite runtime for binary classification compatibility */ + const char *opposite_label; + if (rt->stcf->class_name) { + /* Multi-class: find a different class (simplified for now) */ + opposite_label = strcmp(get_class_label(rt->stcf), "S") == 0 ? "H" : "S"; + } + else { + /* Binary: use opposite spam/ham */ + opposite_label = rt->stcf->is_spam ? "H" : "S"; + } + auto opposite_rt_maybe = redis_stat_runtime::maybe_recover_from_mempool(task, + rt->redis_object_expanded, + opposite_label); + + if (!opposite_rt_maybe) { + msg_err_task("internal error: cannot find opposite runtime for cookie %s", cookie); + return 0; + } + + if (rt->stcf->is_spam) { + filler_func(rt, L, lua_tointeger(L, 4), 6); + filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 3), 5); + } + else { + filler_func(rt, L, lua_tointeger(L, 3), 5); + filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 4), 6); + } - if (rt->stcf->is_spam) { - filler_func(rt, L, lua_tointeger(L, 4), 6); - filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 3), 5); + /* Process all tokens */ + g_assert(rt->tokens != nullptr); + rt->process_tokens(rt->tokens); + opposite_rt_maybe.value()->process_tokens(rt->tokens); } else { - filler_func(rt, L, lua_tointeger(L, 3), 5); - filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 4), 6); - } + /* Multi-class format: [learned_counts_table, outputs_table] */ + + /* Get learned counts table (index 3) and outputs table (index 4) */ + lua_rawgeti(L, 3, 1); /* learned_counts */ + lua_rawgeti(L, 3, 2); /* outputs */ + + /* Iterate through all class labels to fill all runtimes */ + if (rt->stcf->clcf && rt->stcf->clcf->class_labels) { + GHashTableIter iter; + gpointer key, value; + g_hash_table_iter_init(&iter, rt->stcf->clcf->class_labels); + + while (g_hash_table_iter_next(&iter, &key, &value)) { + const char *class_label = (const char *) value; + + /* Find runtime for this class */ + auto class_rt_maybe = redis_stat_runtime::maybe_recover_from_mempool(task, + rt->redis_object_expanded, + class_label); + + if (class_rt_maybe) { + auto *class_rt = class_rt_maybe.value(); + + /* Get learned count for this class */ + lua_pushstring(L, class_label); + lua_gettable(L, -3); /* learned_counts[class_label] */ + unsigned learned = lua_tointeger(L, -1); + lua_pop(L, 1); + + /* Get outputs for this class */ + lua_pushstring(L, class_label); + lua_gettable(L, -2); /* outputs[class_label] */ + int outputs_pos = lua_gettop(L); + + filler_func(class_rt, L, learned, outputs_pos); + lua_pop(L, 1); + } + } + } + + lua_pop(L, 2); /* Pop learned_counts and outputs tables */ - /* Mark task as being processed */ - task->flags |= RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS | RSPAMD_TASK_FLAG_HAS_HAM_TOKENS; + /* Process tokens for all runtimes */ + g_assert(rt->tokens != nullptr); + rt->process_tokens(rt->tokens); + } - /* Process all tokens */ - g_assert(rt->tokens != nullptr); - rt->process_tokens(rt->tokens); - opposite_rt_maybe.value()->process_tokens(rt->tokens); + /* Tokens processed - no need to set flags in multi-class approach */ } else { /* Error message is on index 3 */ @@ -929,7 +1030,25 @@ rspamd_redis_process_tokens(struct rspamd_task *task, rspamd_lua_task_push(L, task); lua_pushstring(L, rt->redis_object_expanded); lua_pushinteger(L, id); - lua_pushboolean(L, rt->stcf->is_spam); + + /* Send all class labels for multi-class support */ + if (rt->stcf->clcf && rt->stcf->clcf->class_labels && g_hash_table_size(rt->stcf->clcf->class_labels) > 0) { + /* Multi-class: send array of all class labels */ + lua_newtable(L); + GHashTableIter iter; + gpointer key, value; + int idx = 1; + g_hash_table_iter_init(&iter, rt->stcf->clcf->class_labels); + while (g_hash_table_iter_next(&iter, &key, &value)) { + lua_pushstring(L, (const char *) value); /* Use the label, not class name */ + lua_rawseti(L, -2, idx++); + } + } + else { + /* Binary compatibility: send current class label as single string */ + lua_pushstring(L, get_class_label(rt->stcf)); + } + lua_new_text(L, tokens_buf, tokens_len, false); /* Store rt in random cookie */ @@ -1028,7 +1147,7 @@ rspamd_redis_learn_tokens(struct rspamd_task *task, rspamd_lua_task_push(L, task); lua_pushstring(L, rt->redis_object_expanded); lua_pushinteger(L, id); - lua_pushboolean(L, rt->stcf->is_spam); + lua_pushstring(L, get_class_label(rt->stcf)); /* Pass class label instead of boolean */ lua_pushstring(L, rt->stcf->symbol); /* Detect unlearn */ diff --git a/src/libstat/backends/sqlite3_backend.c b/src/libstat/backends/sqlite3_backend.c index 973dc30a76..8f29a3b4ed 100644 --- a/src/libstat/backends/sqlite3_backend.c +++ b/src/libstat/backends/sqlite3_backend.c @@ -589,12 +589,7 @@ rspamd_sqlite3_process_tokens(struct rspamd_task *task, } } - if (rt->cf->is_spam) { - task->flags |= RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS; - } - else { - task->flags |= RSPAMD_TASK_FLAG_HAS_HAM_TOKENS; - } + /* No longer need to set flags - multi-class handles missing data naturally */ } diff --git a/src/libstat/classifiers/bayes.c b/src/libstat/classifiers/bayes.c index 93b5149dad..4a1b0cf32a 100644 --- a/src/libstat/classifiers/bayes.c +++ b/src/libstat/classifiers/bayes.c @@ -94,8 +94,8 @@ inv_chi_square(struct rspamd_task *task, double value, int freedom_deg) } struct bayes_task_closure { - double ham_prob; - double spam_prob; + double ham_prob; /* Kept for binary compatibility */ + double spam_prob; /* Kept for binary compatibility */ double meta_skip_prob; uint64_t processed_tokens; uint64_t total_hits; @@ -103,6 +103,20 @@ struct bayes_task_closure { struct rspamd_task *task; }; +/* Multi-class classification closure */ +struct bayes_multiclass_closure { + double *class_log_probs; /* Array of log probabilities for each class */ + uint64_t *class_learns; /* Learning counts for each class */ + char **class_names; /* Array of class names */ + unsigned int num_classes; /* Number of classes */ + double meta_skip_prob; + uint64_t processed_tokens; + uint64_t total_hits; + uint64_t text_tokens; + struct rspamd_task *task; + struct rspamd_classifier_config *cfg; +}; + /* * Mathematically we use pow(complexity, complexity), where complexity is the * window index @@ -248,6 +262,301 @@ bayes_classify_token(struct rspamd_classifier *ctx, } } +/* + * Multinomial token classification for multi-class Bayes + */ +static void +bayes_classify_token_multiclass(struct rspamd_classifier *ctx, + rspamd_token_t *tok, + struct bayes_multiclass_closure *cl) +{ + unsigned int i, j; + int id; + struct rspamd_statfile *st; + struct rspamd_task *task; + const char *token_type = "txt"; + double val, fw, w; + unsigned int *class_counts; + unsigned int total_count = 0; + + task = cl->task; + + /* Skip meta tokens probabilistically if configured */ + if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_META && cl->meta_skip_prob > 0) { + val = rspamd_random_double_fast(); + if (val <= cl->meta_skip_prob) { + return; + } + token_type = "meta"; + } + + /* Allocate array for class counts */ + class_counts = g_alloca(cl->num_classes * sizeof(unsigned int)); + memset(class_counts, 0, cl->num_classes * sizeof(unsigned int)); + + /* Collect counts for each class */ + for (i = 0; i < ctx->statfiles_ids->len; i++) { + id = g_array_index(ctx->statfiles_ids, int, i); + st = g_ptr_array_index(ctx->ctx->statfiles, id); + g_assert(st != NULL); + val = tok->values[id]; + + if (val > 0) { + /* Find which class this statfile belongs to */ + for (j = 0; j < cl->num_classes; j++) { + if (st->stcf->class_name && + strcmp(st->stcf->class_name, cl->class_names[j]) == 0) { + class_counts[j] += val; + total_count += val; + cl->total_hits += val; + break; + } + } + } + } + + /* Calculate multinomial probability for this token */ + if (total_count >= ctx->cfg->min_token_hits) { + /* Feature weight calculation */ + if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_UNIGRAM) { + fw = 1.0; + } + else { + fw = feature_weight[tok->window_idx % G_N_ELEMENTS(feature_weight)]; + } + + w = (fw * total_count) / (1.0 + fw * total_count); + + /* Apply multinomial model for each class */ + for (j = 0; j < cl->num_classes; j++) { + double class_freq = (double) class_counts[j] / MAX(1.0, (double) cl->class_learns[j]); + double class_prob = PROB_COMBINE(class_freq, total_count, w, 1.0 / cl->num_classes); + + /* Skip probabilities too close to uniform (1/num_classes) */ + double uniform_prior = 1.0 / cl->num_classes; + if (fabs(class_prob - uniform_prior) < ctx->cfg->min_prob_strength) { + continue; + } + + cl->class_log_probs[j] += log(class_prob); + } + + cl->processed_tokens++; + if (!(tok->flags & RSPAMD_STAT_TOKEN_FLAG_META)) { + cl->text_tokens++; + } + + if (tok->t1 && tok->t2) { + msg_debug_bayes("token(%s) %uL <%*s:%*s>: weight: %.3f, total_count: %ud, " + "processed for %u classes", + token_type, tok->data, + (int) tok->t1->stemmed.len, tok->t1->stemmed.begin, + (int) tok->t2->stemmed.len, tok->t2->stemmed.begin, + fw, total_count, cl->num_classes); + } + } +} + +/* + * Multinomial Bayes classification with Fisher confidence + */ +static gboolean +bayes_classify_multiclass(struct rspamd_classifier *ctx, + GPtrArray *tokens, + struct rspamd_task *task) +{ + struct bayes_multiclass_closure cl; + rspamd_token_t *tok; + unsigned int i, j, text_tokens = 0; + int id; + struct rspamd_statfile *st; + rspamd_multiclass_result_t *result; + double *normalized_probs; + double max_log_prob = -INFINITY; + unsigned int winning_class_idx = 0; + double confidence; + + g_assert(ctx != NULL); + g_assert(tokens != NULL); + + /* Initialize multi-class closure */ + memset(&cl, 0, sizeof(cl)); + cl.task = task; + cl.cfg = ctx->cfg; + + /* Get class information from classifier config */ + if (!ctx->cfg->class_names || ctx->cfg->class_names->len < 2) { + msg_debug_bayes("insufficient classes for multiclass classification"); + return TRUE; /* Fall back to binary mode */ + } + + cl.num_classes = ctx->cfg->class_names->len; + cl.class_names = (char **) ctx->cfg->class_names->pdata; + cl.class_log_probs = g_alloca(cl.num_classes * sizeof(double)); + cl.class_learns = g_alloca(cl.num_classes * sizeof(uint64_t)); + + /* Initialize probabilities and get learning counts */ + for (i = 0; i < cl.num_classes; i++) { + cl.class_log_probs[i] = 0.0; + cl.class_learns[i] = 0; + } + + /* Collect learning counts for each class */ + for (i = 0; i < ctx->statfiles_ids->len; i++) { + id = g_array_index(ctx->statfiles_ids, int, i); + st = g_ptr_array_index(ctx->ctx->statfiles, id); + g_assert(st != NULL); + + for (j = 0; j < cl.num_classes; j++) { + if (st->stcf->class_name && + strcmp(st->stcf->class_name, cl.class_names[j]) == 0) { + cl.class_learns[j] += st->backend->total_learns(task, + g_ptr_array_index(task->stat_runtimes, id), ctx->ctx); + break; + } + } + } + + /* Check minimum learns requirement */ + if (ctx->cfg->min_learns > 0) { + for (i = 0; i < cl.num_classes; i++) { + if (cl.class_learns[i] < ctx->cfg->min_learns) { + msg_info_task("not classified as %s. The class needs more " + "training samples. Currently: %ul; minimum %ud required", + cl.class_names[i], cl.class_learns[i], ctx->cfg->min_learns); + return TRUE; + } + } + } + + /* Count text tokens */ + for (i = 0; i < tokens->len; i++) { + tok = g_ptr_array_index(tokens, i); + if (!(tok->flags & RSPAMD_STAT_TOKEN_FLAG_META)) { + text_tokens++; + } + } + + if (text_tokens == 0) { + msg_info_task("skipped classification as there are no text tokens. " + "Total tokens: %ud", + tokens->len); + return TRUE; + } + + /* Set meta token skip probability */ + if (text_tokens > tokens->len - text_tokens) { + cl.meta_skip_prob = 0.0; + } + else { + cl.meta_skip_prob = 1.0 - (double) text_tokens / tokens->len; + } + + /* Process all tokens */ + for (i = 0; i < tokens->len; i++) { + tok = g_ptr_array_index(tokens, i); + bayes_classify_token_multiclass(ctx, tok, &cl); + } + + if (cl.processed_tokens == 0) { + msg_info_bayes("no tokens found in bayes database " + "(%ud total tokens, %ud text tokens), ignore stats", + tokens->len, text_tokens); + return TRUE; + } + + if (ctx->cfg->min_tokens > 0 && + cl.text_tokens < (int) (ctx->cfg->min_tokens * 0.1)) { + msg_info_bayes("ignore bayes probability since we have " + "found too few text tokens: %uL (of %ud checked), " + "at least %d required", + cl.text_tokens, text_tokens, + (int) (ctx->cfg->min_tokens * 0.1)); + return TRUE; + } + + /* Normalize probabilities using softmax */ + normalized_probs = g_alloca(cl.num_classes * sizeof(double)); + + /* Find maximum for numerical stability */ + for (i = 0; i < cl.num_classes; i++) { + if (cl.class_log_probs[i] > max_log_prob) { + max_log_prob = cl.class_log_probs[i]; + winning_class_idx = i; + } + } + + /* Apply softmax normalization */ + double sum_exp = 0.0; + for (i = 0; i < cl.num_classes; i++) { + normalized_probs[i] = exp(cl.class_log_probs[i] - max_log_prob); + sum_exp += normalized_probs[i]; + } + + if (sum_exp > 0) { + for (i = 0; i < cl.num_classes; i++) { + normalized_probs[i] /= sum_exp; + } + } + else { + /* Fallback to uniform distribution */ + for (i = 0; i < cl.num_classes; i++) { + normalized_probs[i] = 1.0 / cl.num_classes; + } + } + + /* Calculate confidence using Fisher method for the winning class */ + if (max_log_prob > -300) { + confidence = 1.0 - inv_chi_square(task, max_log_prob, cl.processed_tokens); + } + else { + confidence = normalized_probs[winning_class_idx]; + } + + /* Create and store multiclass result */ + result = g_new0(rspamd_multiclass_result_t, 1); + result->class_names = g_new(char *, cl.num_classes); + result->probabilities = g_new(double, cl.num_classes); + result->num_classes = cl.num_classes; + result->winning_class = cl.class_names[winning_class_idx]; /* Reference, not copy */ + result->confidence = confidence; + + for (i = 0; i < cl.num_classes; i++) { + result->class_names[i] = g_strdup(cl.class_names[i]); + result->probabilities[i] = normalized_probs[i]; + } + + rspamd_task_set_multiclass_result(task, result); + + /* Insert symbol for winning class if confidence is significant */ + if (confidence > 0.05) { + char sumbuf[32]; + double final_prob = rspamd_normalize_probability(confidence, 0.5); + + rspamd_snprintf(sumbuf, sizeof(sumbuf), "%.2f%%", confidence * 100.0); + + /* Find the statfile for the winning class to get the symbol */ + for (i = 0; i < ctx->statfiles_ids->len; i++) { + id = g_array_index(ctx->statfiles_ids, int, i); + st = g_ptr_array_index(ctx->ctx->statfiles, id); + + if (st->stcf->class_name && + strcmp(st->stcf->class_name, cl.class_names[winning_class_idx]) == 0) { + rspamd_task_insert_result(task, st->stcf->symbol, final_prob, sumbuf); + break; + } + } + + msg_debug_bayes("multiclass classification: winning class '%s' with " + "probability %.3f, confidence %.3f, %uL tokens processed", + cl.class_names[winning_class_idx], + normalized_probs[winning_class_idx], + confidence, cl.processed_tokens); + } + + return TRUE; +} + gboolean bayes_init(struct rspamd_config *cfg, @@ -279,6 +588,28 @@ bayes_classify(struct rspamd_classifier *ctx, g_assert(ctx != NULL); g_assert(tokens != NULL); + /* Check if this is a multi-class classifier */ + if (ctx->cfg->class_names && ctx->cfg->class_names->len >= 2) { + /* Verify that at least one statfile has class_name set (indicating new multi-class config) */ + gboolean has_class_names = FALSE; + for (i = 0; i < ctx->statfiles_ids->len; i++) { + int id = g_array_index(ctx->statfiles_ids, int, i); + struct rspamd_statfile *st = g_ptr_array_index(ctx->ctx->statfiles, id); + if (st->stcf->class_name) { + has_class_names = TRUE; + break; + } + } + + if (has_class_names) { + msg_debug_bayes("using multiclass classification with %u classes", + (unsigned int) ctx->cfg->class_names->len); + return bayes_classify_multiclass(ctx, tokens, task); + } + } + + /* Fall back to binary classification */ + msg_debug_bayes("using binary classification"); memset(&cl, 0, sizeof(cl)); cl.task = task; @@ -549,3 +880,152 @@ bayes_learn_spam(struct rspamd_classifier *ctx, return TRUE; } + +gboolean +bayes_learn_class(struct rspamd_classifier *ctx, + GPtrArray *tokens, + struct rspamd_task *task, + const char *class_name, + gboolean unlearn, + GError **err) +{ + unsigned int i, j, total_cnt; + int id; + struct rspamd_statfile *st; + rspamd_token_t *tok; + gboolean incrementing; + unsigned int *class_counts = NULL; + struct rspamd_statfile **class_statfiles = NULL; + unsigned int num_classes = 0; + + g_assert(ctx != NULL); + g_assert(tokens != NULL); + g_assert(class_name != NULL); + + incrementing = ctx->cfg->flags & RSPAMD_FLAG_CLASSIFIER_INCREMENTING_BACKEND; + + /* Count classes and prepare arrays for multi-class learning */ + if (ctx->cfg->class_names && ctx->cfg->class_names->len > 0) { + num_classes = ctx->cfg->class_names->len; + class_counts = g_alloca(num_classes * sizeof(unsigned int)); + class_statfiles = g_alloca(num_classes * sizeof(struct rspamd_statfile *)); + memset(class_counts, 0, num_classes * sizeof(unsigned int)); + memset(class_statfiles, 0, num_classes * sizeof(struct rspamd_statfile *)); + } + + for (i = 0; i < tokens->len; i++) { + total_cnt = 0; + tok = g_ptr_array_index(tokens, i); + + /* Reset class counts for this token */ + if (num_classes > 0) { + memset(class_counts, 0, num_classes * sizeof(unsigned int)); + } + + for (j = 0; j < ctx->statfiles_ids->len; j++) { + id = g_array_index(ctx->statfiles_ids, int, j); + st = g_ptr_array_index(ctx->ctx->statfiles, id); + g_assert(st != NULL); + + /* Determine if this statfile matches our target class */ + gboolean is_target_class = FALSE; + if (st->stcf->class_name) { + /* Multi-class: exact class name match */ + is_target_class = (strcmp(st->stcf->class_name, class_name) == 0); + } + else { + /* Legacy binary: map class_name to spam/ham */ + if (strcmp(class_name, "spam") == 0 || strcmp(class_name, "S") == 0) { + is_target_class = st->stcf->is_spam; + } + else if (strcmp(class_name, "ham") == 0 || strcmp(class_name, "H") == 0) { + is_target_class = !st->stcf->is_spam; + } + } + + if (is_target_class) { + /* Learning: increment the target class */ + if (incrementing) { + tok->values[id] = 1; + } + else { + tok->values[id]++; + } + total_cnt += tok->values[id]; + + /* Track class counts for debugging */ + if (num_classes > 0) { + for (unsigned int k = 0; k < num_classes; k++) { + const char *check_class = (const char *) g_ptr_array_index(ctx->cfg->class_names, k); + if (st->stcf->class_name && strcmp(st->stcf->class_name, check_class) == 0) { + class_counts[k] += tok->values[id]; + class_statfiles[k] = st; + break; + } + } + } + } + else { + /* Unlearning: decrement other classes if unlearn flag is set */ + if (tok->values[id] > 0 && unlearn) { + if (incrementing) { + tok->values[id] = -1; + } + else { + tok->values[id]--; + } + total_cnt += tok->values[id]; + + /* Track class counts for debugging */ + if (num_classes > 0) { + for (unsigned int k = 0; k < num_classes; k++) { + const char *check_class = (const char *) g_ptr_array_index(ctx->cfg->class_names, k); + if (st->stcf->class_name && strcmp(st->stcf->class_name, check_class) == 0) { + class_counts[k] += tok->values[id]; + class_statfiles[k] = st; + break; + } + } + } + } + else if (incrementing) { + tok->values[id] = 0; + } + } + } + + /* Debug logging */ + if (tok->t1 && tok->t2) { + if (num_classes > 0) { + GString *debug_str = g_string_new(""); + for (unsigned int k = 0; k < num_classes; k++) { + const char *check_class = (const char *) g_ptr_array_index(ctx->cfg->class_names, k); + g_string_append_printf(debug_str, "%s:%d ", check_class, class_counts[k]); + } + msg_debug_bayes("token %uL <%*s:%*s>: window: %d, total_count: %d, " + "class_counts: %s", + tok->data, + (int) tok->t1->stemmed.len, tok->t1->stemmed.begin, + (int) tok->t2->stemmed.len, tok->t2->stemmed.begin, + tok->window_idx, total_cnt, debug_str->str); + g_string_free(debug_str, TRUE); + } + else { + msg_debug_bayes("token %uL <%*s:%*s>: window: %d, total_count: %d, " + "class: %s", + tok->data, + (int) tok->t1->stemmed.len, tok->t1->stemmed.begin, + (int) tok->t2->stemmed.len, tok->t2->stemmed.begin, + tok->window_idx, total_cnt, class_name); + } + } + else { + msg_debug_bayes("token %uL : window: %d, total_count: %d, " + "class: %s", + tok->data, + tok->window_idx, total_cnt, class_name); + } + } + + return TRUE; +} diff --git a/src/libstat/classifiers/classifiers.h b/src/libstat/classifiers/classifiers.h index 22978e6734..cab658146e 100644 --- a/src/libstat/classifiers/classifiers.h +++ b/src/libstat/classifiers/classifiers.h @@ -54,6 +54,13 @@ struct rspamd_stat_classifier { gboolean unlearn, GError **err); + gboolean (*learn_class_func)(struct rspamd_classifier *ctx, + GPtrArray *input, + struct rspamd_task *task, + const char *class_name, + gboolean unlearn, + GError **err); + void (*fin_func)(struct rspamd_classifier *cl); }; @@ -73,6 +80,13 @@ gboolean bayes_learn_spam(struct rspamd_classifier *ctx, gboolean unlearn, GError **err); +gboolean bayes_learn_class(struct rspamd_classifier *ctx, + GPtrArray *tokens, + struct rspamd_task *task, + const char *class_name, + gboolean unlearn, + GError **err); + void bayes_fin(struct rspamd_classifier *); /* Generic lua classifier */ diff --git a/src/libstat/stat_api.h b/src/libstat/stat_api.h index 811566ad38..aa6111a8b2 100644 --- a/src/libstat/stat_api.h +++ b/src/libstat/stat_api.h @@ -107,6 +107,23 @@ rspamd_stat_result_t rspamd_stat_learn(struct rspamd_task *task, unsigned int stage, GError **err); +/** + * Learn task as a specific class, task must be processed prior to this call + * @param task task to learn + * @param class_name name of the class to learn (e.g., "spam", "ham", "transactional") + * @param L lua state + * @param classifier NULL to learn all classifiers, name to learn a specific one + * @param stage learning stage + * @param err error returned + * @return TRUE if task has been learned + */ +rspamd_stat_result_t rspamd_stat_learn_class(struct rspamd_task *task, + const char *class_name, + lua_State *L, + const char *classifier, + unsigned int stage, + GError **err); + /** * Get the overall statistics for all statfile backends * @param cfg configuration @@ -120,6 +137,43 @@ rspamd_stat_result_t rspamd_stat_statistics(struct rspamd_task *task, void rspamd_stat_unload(void); +/** + * Multi-class classification result structure + */ +typedef struct { + char **class_names; /**< Array of class names */ + double *probabilities; /**< Array of probabilities for each class */ + unsigned int num_classes; /**< Number of classes */ + const char *winning_class; /**< Name of the winning class (reference, not owned) */ + double confidence; /**< Confidence of the winning class */ +} rspamd_multiclass_result_t; + +/** + * Set multi-class classification result for a task + */ +void rspamd_task_set_multiclass_result(struct rspamd_task *task, + rspamd_multiclass_result_t *result); + +/** + * Get multi-class classification result from a task + */ +rspamd_multiclass_result_t *rspamd_task_get_multiclass_result(struct rspamd_task *task); + +/** + * Free multi-class result structure + */ +void rspamd_multiclass_result_free(rspamd_multiclass_result_t *result); + +/** + * Set autolearn class for a task + */ +void rspamd_task_set_autolearn_class(struct rspamd_task *task, const char *class_name); + +/** + * Get autolearn class from a task + */ +const char *rspamd_task_get_autolearn_class(struct rspamd_task *task); + #ifdef __cplusplus } #endif diff --git a/src/libstat/stat_config.c b/src/libstat/stat_config.c index 8a5313df22..5ada7d4681 100644 --- a/src/libstat/stat_config.c +++ b/src/libstat/stat_config.c @@ -28,6 +28,7 @@ static struct rspamd_stat_classifier lua_classifier = { .init_func = lua_classifier_init, .classify_func = lua_classifier_classify, .learn_spam_func = lua_classifier_learn_spam, + .learn_class_func = NULL, /* TODO: implement lua multi-class learning */ .fin_func = NULL, }; @@ -37,6 +38,7 @@ static struct rspamd_stat_classifier stat_classifiers[] = { .init_func = bayes_init, .classify_func = bayes_classify, .learn_spam_func = bayes_learn_spam, + .learn_class_func = bayes_learn_class, .fin_func = bayes_fin, }}; @@ -68,8 +70,7 @@ static struct rspamd_stat_tokenizer stat_tokenizers[] = { .dec_learns = rspamd_##eltn##_dec_learns, \ .get_stat = rspamd_##eltn##_get_stat, \ .load_tokenizer_config = rspamd_##eltn##_load_tokenizer_config, \ - .close = rspamd_##eltn##_close \ - } + .close = rspamd_##eltn##_close} #define RSPAMD_STAT_BACKEND_ELT_READONLY(nam, eltn) \ { \ .name = #nam, \ @@ -85,8 +86,7 @@ static struct rspamd_stat_tokenizer stat_tokenizers[] = { .dec_learns = NULL, \ .get_stat = rspamd_##eltn##_get_stat, \ .load_tokenizer_config = rspamd_##eltn##_load_tokenizer_config, \ - .close = rspamd_##eltn##_close \ - } + .close = rspamd_##eltn##_close} static struct rspamd_stat_backend stat_backends[] = { RSPAMD_STAT_BACKEND_ELT(mmap, mmaped_file), @@ -101,8 +101,7 @@ static struct rspamd_stat_backend stat_backends[] = { .runtime = rspamd_stat_cache_##eltn##_runtime, \ .check = rspamd_stat_cache_##eltn##_check, \ .learn = rspamd_stat_cache_##eltn##_learn, \ - .close = rspamd_stat_cache_##eltn##_close \ - } + .close = rspamd_stat_cache_##eltn##_close} static struct rspamd_stat_cache stat_caches[] = { RSPAMD_STAT_CACHE_ELT(sqlite3, sqlite3), diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c index 176064087b..5126fd2cc3 100644 --- a/src/libstat/stat_process.c +++ b/src/libstat/stat_process.c @@ -32,6 +32,78 @@ static const double similarity_threshold = 80.0; +void rspamd_task_set_multiclass_result(struct rspamd_task *task, rspamd_multiclass_result_t *result) +{ + g_assert(task != NULL); + g_assert(result != NULL); + + rspamd_mempool_set_variable(task->task_pool, "multiclass_bayes_result", result, + (rspamd_mempool_destruct_t) rspamd_multiclass_result_free); +} + +rspamd_multiclass_result_t * +rspamd_task_get_multiclass_result(struct rspamd_task *task) +{ + g_assert(task != NULL); + + return (rspamd_multiclass_result_t *) rspamd_mempool_get_variable(task->task_pool, + "multiclass_bayes_result"); +} + +void rspamd_multiclass_result_free(rspamd_multiclass_result_t *result) +{ + if (result == NULL) { + return; + } + + g_free(result->class_names); + g_free(result->probabilities); + /* winning_class is a reference, not owned - don't free */ + g_free(result); +} + +void rspamd_task_set_autolearn_class(struct rspamd_task *task, const char *class_name) +{ + g_assert(task != NULL); + g_assert(class_name != NULL); + + /* Store the class name in the mempool */ + const char *class_name_copy = rspamd_mempool_strdup(task->task_pool, class_name); + rspamd_mempool_set_variable(task->task_pool, "autolearn_class", + (gpointer) class_name_copy, NULL); + + /* Set the appropriate flags */ + task->flags |= RSPAMD_TASK_FLAG_LEARN_CLASS; + + /* For backward compatibility, also set binary flags */ + if (strcmp(class_name, "spam") == 0 || strcmp(class_name, "S") == 0) { + task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM; + } + else if (strcmp(class_name, "ham") == 0 || strcmp(class_name, "H") == 0) { + task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM; + } +} + +const char * +rspamd_task_get_autolearn_class(struct rspamd_task *task) +{ + g_assert(task != NULL); + + if (task->flags & RSPAMD_TASK_FLAG_LEARN_CLASS) { + return (const char *) rspamd_mempool_get_variable(task->task_pool, "autolearn_class"); + } + + /* Fallback to binary flags for backward compatibility */ + if (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) { + return "spam"; + } + else if (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM) { + return "ham"; + } + + return NULL; +} + static void rspamd_stat_tokenize_parts_metadata(struct rspamd_stat_ctx *st_ctx, struct rspamd_task *task) @@ -394,18 +466,9 @@ rspamd_stat_classifiers_process(struct rspamd_stat_ctx *st_ctx, } /* - * Do not classify a message if some class is missing + * Multi-class approach: don't check for missing classes + * Missing tokens naturally result in 0 probability */ - if (!(task->flags & RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS)) { - msg_info_task("skip statistics as SPAM class is missing"); - - return; - } - if (!(task->flags & RSPAMD_TASK_FLAG_HAS_HAM_TOKENS)) { - msg_info_task("skip statistics as HAM class is missing"); - - return; - } for (i = 0; i < st_ctx->classifiers->len; i++) { cl = g_ptr_array_index(st_ctx->classifiers, i); @@ -565,7 +628,24 @@ rspamd_stat_cache_check(struct rspamd_stat_ctx *st_ctx, if (sel->cache && sel->cachecf) { rt = cl->cache->runtime(task, sel->cachecf, FALSE); - learn_res = cl->cache->check(task, spam, rt); + + /* For multi-class learning, determine spam boolean from class name if available */ + gboolean cache_spam = spam; /* Default to original spam parameter */ + const char *autolearn_class = rspamd_task_get_autolearn_class(task); + if (autolearn_class) { + if (strcmp(autolearn_class, "spam") == 0 || strcmp(autolearn_class, "S") == 0) { + cache_spam = TRUE; + } + else if (strcmp(autolearn_class, "ham") == 0 || strcmp(autolearn_class, "H") == 0) { + cache_spam = FALSE; + } + else { + /* For other classes, use a heuristic or default to spam for cache purposes */ + cache_spam = TRUE; /* Non-ham classes are treated as spam for cache */ + } + } + + learn_res = cl->cache->check(task, cache_spam, rt); } if (learn_res == RSPAMD_LEARN_IGNORE) { @@ -658,9 +738,63 @@ rspamd_stat_classifiers_learn(struct rspamd_stat_ctx *st_ctx, continue; } - if (cl->subrs->learn_spam_func(cl, task->tokens, task, spam, - task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) { - learned = TRUE; + /* Check if classifier supports multi-class learning and if we should use it */ + if (cl->subrs->learn_class_func && cl->cfg->class_names && cl->cfg->class_names->len > 2) { + /* Multi-class learning: determine class name from task flags or autolearn result */ + const char *class_name = NULL; + + if (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) { + /* Find spam class name */ + for (unsigned int k = 0; k < cl->cfg->class_names->len; k++) { + const char *check_class = (const char *) g_ptr_array_index(cl->cfg->class_names, k); + /* Look for statfile with this class that is spam */ + GList *cur = cl->cfg->statfiles; + while (cur) { + struct rspamd_statfile_config *stcf = (struct rspamd_statfile_config *) cur->data; + if (stcf->class_name && strcmp(stcf->class_name, check_class) == 0 && stcf->is_spam) { + class_name = check_class; + break; + } + cur = g_list_next(cur); + } + if (class_name) break; + } + if (!class_name) class_name = "spam"; /* fallback */ + } + else if (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM) { + /* Find ham class name */ + for (unsigned int k = 0; k < cl->cfg->class_names->len; k++) { + const char *check_class = (const char *) g_ptr_array_index(cl->cfg->class_names, k); + /* Look for statfile with this class that is ham */ + GList *cur = cl->cfg->statfiles; + while (cur) { + struct rspamd_statfile_config *stcf = (struct rspamd_statfile_config *) cur->data; + if (stcf->class_name && strcmp(stcf->class_name, check_class) == 0 && !stcf->is_spam) { + class_name = check_class; + break; + } + cur = g_list_next(cur); + } + if (class_name) break; + } + if (!class_name) class_name = "ham"; /* fallback */ + } + else { + /* Fallback to spam/ham based on the spam parameter */ + class_name = spam ? "spam" : "ham"; + } + + if (cl->subrs->learn_class_func(cl, task->tokens, task, class_name, + task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) { + learned = TRUE; + } + } + else { + /* Binary learning: use existing function */ + if (cl->subrs->learn_spam_func(cl, task->tokens, task, spam, + task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) { + learned = TRUE; + } } } @@ -870,7 +1004,24 @@ rspamd_stat_backends_post_learn(struct rspamd_stat_ctx *st_ctx, if (cl->cache) { cache_run = cl->cache->runtime(task, cl->cachecf, TRUE); - cl->cache->learn(task, spam, cache_run); + + /* For multi-class learning, determine spam boolean from class name if available */ + gboolean cache_spam = spam; /* Default to original spam parameter */ + const char *autolearn_class = rspamd_task_get_autolearn_class(task); + if (autolearn_class) { + if (strcmp(autolearn_class, "spam") == 0 || strcmp(autolearn_class, "S") == 0) { + cache_spam = TRUE; + } + else if (strcmp(autolearn_class, "ham") == 0 || strcmp(autolearn_class, "H") == 0) { + cache_spam = FALSE; + } + else { + /* For other classes, use a heuristic or default to spam for cache purposes */ + cache_spam = TRUE; /* Non-ham classes are treated as spam for cache */ + } + } + + cl->cache->learn(task, cache_spam, cache_run); } } @@ -879,6 +1030,218 @@ rspamd_stat_backends_post_learn(struct rspamd_stat_ctx *st_ctx, return res; } +static gboolean +rspamd_stat_classifiers_learn_class(struct rspamd_stat_ctx *st_ctx, + struct rspamd_task *task, + const char *classifier, + const char *class_name, + GError **err) +{ + struct rspamd_classifier *cl, *sel = NULL; + unsigned int i; + gboolean learned = FALSE, too_small = FALSE, too_large = FALSE; + + if ((task->flags & RSPAMD_TASK_FLAG_ALREADY_LEARNED) && err != NULL && + *err == NULL) { + /* Do not learn twice */ + g_set_error(err, rspamd_stat_quark(), 208, "<%s> has been already " + "learned as %s, ignore it", + MESSAGE_FIELD(task, message_id), + class_name); + + return FALSE; + } + + /* Check whether we have learned that file */ + for (i = 0; i < st_ctx->classifiers->len; i++) { + cl = g_ptr_array_index(st_ctx->classifiers, i); + + /* Skip other classifiers if they are not needed */ + if (classifier != NULL && (cl->cfg->name == NULL || + g_ascii_strcasecmp(classifier, cl->cfg->name) != 0)) { + continue; + } + + sel = cl; + + /* Now check max and min tokens */ + if (cl->cfg->min_tokens > 0 && task->tokens->len < cl->cfg->min_tokens) { + msg_info_task( + "<%s> contains less tokens than required for %s classifier: " + "%ud < %ud", + MESSAGE_FIELD(task, message_id), + cl->cfg->name, + task->tokens->len, + cl->cfg->min_tokens); + too_small = TRUE; + continue; + } + else if (cl->cfg->max_tokens > 0 && task->tokens->len > cl->cfg->max_tokens) { + msg_info_task( + "<%s> contains more tokens than allowed for %s classifier: " + "%ud > %ud", + MESSAGE_FIELD(task, message_id), + cl->cfg->name, + task->tokens->len, + cl->cfg->max_tokens); + too_large = TRUE; + continue; + } + + /* Use the new multi-class learning function if available */ + if (cl->subrs->learn_class_func) { + if (cl->subrs->learn_class_func(cl, task->tokens, task, class_name, + task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) { + learned = TRUE; + } + } + else { + /* Fallback to binary learning with class name mapping */ + gboolean is_spam; + if (strcmp(class_name, "spam") == 0 || strcmp(class_name, "S") == 0) { + is_spam = TRUE; + } + else if (strcmp(class_name, "ham") == 0 || strcmp(class_name, "H") == 0) { + is_spam = FALSE; + } + else { + /* For unknown classes with binary classifier, skip */ + msg_info_task("skipping class '%s' for binary classifier %s", + class_name, cl->cfg->name); + continue; + } + + if (cl->subrs->learn_spam_func(cl, task->tokens, task, is_spam, + task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) { + learned = TRUE; + } + } + } + + if (sel == NULL) { + if (classifier) { + g_set_error(err, rspamd_stat_quark(), 404, "cannot find classifier " + "with name %s", + classifier); + } + else { + g_set_error(err, rspamd_stat_quark(), 404, "no classifiers defined"); + } + + return FALSE; + } + + if (!learned && err && *err == NULL) { + if (too_large) { + g_set_error(err, rspamd_stat_quark(), 204, + "<%s> contains more tokens than allowed for %s classifier: " + "%d > %d", + MESSAGE_FIELD(task, message_id), + sel->cfg->name, + task->tokens->len, + sel->cfg->max_tokens); + } + else if (too_small) { + g_set_error(err, rspamd_stat_quark(), 204, + "<%s> contains less tokens than required for %s classifier: " + "%d < %d", + MESSAGE_FIELD(task, message_id), + sel->cfg->name, + task->tokens->len, + sel->cfg->min_tokens); + } + } + + return learned; +} + +rspamd_stat_result_t +rspamd_stat_learn_class(struct rspamd_task *task, + const char *class_name, + lua_State *L, + const char *classifier, + unsigned int stage, + GError **err) +{ + struct rspamd_stat_ctx *st_ctx; + rspamd_stat_result_t ret = RSPAMD_STAT_PROCESS_OK; + + /* + * We assume now that a task has been already classified before + * coming to learn + */ + g_assert(RSPAMD_TASK_IS_CLASSIFIED(task)); + + st_ctx = rspamd_stat_get_ctx(); + g_assert(st_ctx != NULL); + + msg_debug_bayes("learn class stage %d has been called for class '%s'", stage, class_name); + + if (st_ctx->classifiers->len == 0) { + msg_debug_bayes("no classifiers defined"); + task->processed_stages |= stage; + return ret; + } + + if (task->message == NULL) { + ret = RSPAMD_STAT_PROCESS_ERROR; + if (err && *err == NULL) { + g_set_error(err, rspamd_stat_quark(), 500, + "Trying to learn an empty message"); + } + + task->processed_stages |= stage; + return ret; + } + + if (stage == RSPAMD_TASK_STAGE_LEARN_PRE) { + /* Process classifiers - determine spam boolean for compatibility */ + gboolean spam = (strcmp(class_name, "spam") == 0 || strcmp(class_name, "S") == 0); + rspamd_stat_preprocess(st_ctx, task, TRUE, spam); + + if (!rspamd_stat_cache_check(st_ctx, task, classifier, spam, err)) { + msg_debug_bayes("cache check failed, skip learning"); + return RSPAMD_STAT_PROCESS_ERROR; + } + } + else if (stage == RSPAMD_TASK_STAGE_LEARN) { + /* Process classifiers */ + if (!rspamd_stat_classifiers_learn_class(st_ctx, task, classifier, + class_name, err)) { + if (err && *err == NULL) { + g_set_error(err, rspamd_stat_quark(), 500, + "Unknown statistics error, found when learning classifiers;" + " classifier: %s", + task->classifier); + } + return RSPAMD_STAT_PROCESS_ERROR; + } + + /* Process backends - determine spam boolean for compatibility */ + gboolean spam = (strcmp(class_name, "spam") == 0 || strcmp(class_name, "S") == 0); + if (!rspamd_stat_backends_learn(st_ctx, task, classifier, spam, err)) { + if (err && *err == NULL) { + g_set_error(err, rspamd_stat_quark(), 500, + "Unknown statistics error, found when storing data on backend;" + " classifier: %s", + task->classifier); + } + return RSPAMD_STAT_PROCESS_ERROR; + } + } + else if (stage == RSPAMD_TASK_STAGE_LEARN_POST) { + /* Process backends - determine spam boolean for compatibility */ + gboolean spam = (strcmp(class_name, "spam") == 0 || strcmp(class_name, "S") == 0); + if (!rspamd_stat_backends_post_learn(st_ctx, task, classifier, spam, err)) { + return RSPAMD_STAT_PROCESS_ERROR; + } + } + + task->processed_stages |= stage; + + return ret; +} + rspamd_stat_result_t rspamd_stat_learn(struct rspamd_task *task, gboolean spam, lua_State *L, const char *classifier, unsigned int stage, @@ -1039,12 +1402,11 @@ rspamd_stat_check_autolearn(struct rspamd_task *task) if (mres) { if (mres->score > rspamd_task_get_required_score(task, mres)) { - task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM; - + rspamd_task_set_autolearn_class(task, "spam"); ret = TRUE; } else if (mres->score < 0) { - task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM; + rspamd_task_set_autolearn_class(task, "ham"); ret = TRUE; } } @@ -1076,12 +1438,11 @@ rspamd_stat_check_autolearn(struct rspamd_task *task) if (mres) { if (mres->score >= spam_score) { - task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM; - + rspamd_task_set_autolearn_class(task, "spam"); ret = TRUE; } else if (mres->score <= ham_score) { - task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM; + rspamd_task_set_autolearn_class(task, "ham"); ret = TRUE; } } @@ -1117,11 +1478,16 @@ rspamd_stat_check_autolearn(struct rspamd_task *task) /* We can have immediate results */ if (lua_ret) { if (strcmp(lua_ret, "ham") == 0) { - task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM; + rspamd_task_set_autolearn_class(task, "ham"); ret = TRUE; } else if (strcmp(lua_ret, "spam") == 0) { - task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM; + rspamd_task_set_autolearn_class(task, "spam"); + ret = TRUE; + } + else { + /* Multi-class: any other class name */ + rspamd_task_set_autolearn_class(task, lua_ret); ret = TRUE; } } @@ -1139,79 +1505,138 @@ rspamd_stat_check_autolearn(struct rspamd_task *task) } } else if (ucl_object_type(obj) == UCL_OBJECT) { - /* Try to find autolearn callback */ - if (cl->autolearn_cbref == 0) { - /* We don't have preprocessed cb id, so try to get it */ - if (!rspamd_lua_require_function(L, "lua_bayes_learn", - "autolearn")) { - msg_err_task("cannot get autolearn library from " - "`lua_bayes_learn`"); - } - else { - cl->autolearn_cbref = luaL_ref(L, LUA_REGISTRYINDEX); + /* Check if this is a multi-class autolearn configuration */ + const ucl_object_t *multiclass_obj = ucl_object_lookup(obj, "multiclass"); + + if (multiclass_obj && ucl_object_type(multiclass_obj) == UCL_OBJECT) { + /* Multi-class threshold-based autolearn */ + const ucl_object_t *thresholds_obj = ucl_object_lookup(multiclass_obj, "thresholds"); + + if (thresholds_obj && ucl_object_type(thresholds_obj) == UCL_OBJECT) { + /* Iterate through class thresholds */ + ucl_object_iter_t it = NULL; + const ucl_object_t *class_obj; + const char *class_name; + + while ((class_obj = ucl_object_iterate(thresholds_obj, &it, true))) { + class_name = ucl_object_key(class_obj); + + if (class_name && ucl_object_type(class_obj) == UCL_ARRAY && class_obj->len == 2) { + /* [min_score, max_score] for this class */ + const ucl_object_t *min_elt = ucl_array_find_index(class_obj, 0); + const ucl_object_t *max_elt = ucl_array_find_index(class_obj, 1); + + if ((ucl_object_type(min_elt) == UCL_FLOAT || ucl_object_type(min_elt) == UCL_INT) && + (ucl_object_type(max_elt) == UCL_FLOAT || ucl_object_type(max_elt) == UCL_INT)) { + + double min_score = ucl_object_todouble(min_elt); + double max_score = ucl_object_todouble(max_elt); + + if (mres && mres->score >= min_score && mres->score <= max_score) { + rspamd_task_set_autolearn_class(task, class_name); + ret = TRUE; + msg_debug_bayes("multiclass autolearn: score %.2f matches class '%s' [%.2f, %.2f]", + mres->score, class_name, min_score, max_score); + break; /* Stop at first matching class */ + } + } + } + } } } - - if (cl->autolearn_cbref != -1) { - lua_pushcfunction(L, &rspamd_lua_traceback); - err_idx = lua_gettop(L); - lua_rawgeti(L, LUA_REGISTRYINDEX, cl->autolearn_cbref); - - ptask = lua_newuserdata(L, sizeof(struct rspamd_task *)); - *ptask = task; - rspamd_lua_setclass(L, rspamd_task_classname, -1); - /* Push the whole object as well */ - ucl_object_push_lua(L, obj, true); - - if (lua_pcall(L, 2, 1, err_idx) != 0) { - msg_err_task("call to autolearn script failed: " - "%s", - lua_tostring(L, -1)); + else { + /* Try to find autolearn callback */ + if (cl->autolearn_cbref == 0) { + /* We don't have preprocessed cb id, so try to get it */ + if (!rspamd_lua_require_function(L, "lua_bayes_learn", + "autolearn")) { + msg_err_task("cannot get autolearn library from " + "`lua_bayes_learn`"); + } + else { + cl->autolearn_cbref = luaL_ref(L, LUA_REGISTRYINDEX); + } } - else { - lua_ret = lua_tostring(L, -1); - if (lua_ret) { - if (strcmp(lua_ret, "ham") == 0) { - task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM; - ret = TRUE; - } - else if (strcmp(lua_ret, "spam") == 0) { - task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM; - ret = TRUE; + if (cl->autolearn_cbref != -1) { + lua_pushcfunction(L, &rspamd_lua_traceback); + err_idx = lua_gettop(L); + lua_rawgeti(L, LUA_REGISTRYINDEX, cl->autolearn_cbref); + + ptask = lua_newuserdata(L, sizeof(struct rspamd_task *)); + *ptask = task; + rspamd_lua_setclass(L, rspamd_task_classname, -1); + /* Push the whole object as well */ + ucl_object_push_lua(L, obj, true); + + if (lua_pcall(L, 2, 1, err_idx) != 0) { + msg_err_task("call to autolearn script failed: " + "%s", + lua_tostring(L, -1)); + } + else { + lua_ret = lua_tostring(L, -1); + + if (lua_ret) { + if (strcmp(lua_ret, "ham") == 0) { + rspamd_task_set_autolearn_class(task, "ham"); + ret = TRUE; + } + else if (strcmp(lua_ret, "spam") == 0) { + rspamd_task_set_autolearn_class(task, "spam"); + ret = TRUE; + } + else { + /* Multi-class: any other class name */ + rspamd_task_set_autolearn_class(task, lua_ret); + ret = TRUE; + } } } - } - lua_settop(L, err_idx - 1); + lua_settop(L, err_idx - 1); + } } - } - if (ret) { - /* Do not autolearn if we have this symbol already */ - if (rspamd_stat_has_classifier_symbols(task, mres, cl)) { - ret = FALSE; - task->flags &= ~(RSPAMD_TASK_FLAG_LEARN_HAM | - RSPAMD_TASK_FLAG_LEARN_SPAM); - } - else if (mres != NULL) { - if (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM) { - msg_info_task("<%s>: autolearn ham for classifier " - "'%s' as message's " - "score is negative: %.2f", - MESSAGE_FIELD(task, message_id), cl->cfg->name, - mres->score); - } - else { - msg_info_task("<%s>: autolearn spam for classifier " - "'%s' as message's " - "action is reject, score: %.2f", - MESSAGE_FIELD(task, message_id), cl->cfg->name, - mres->score); + if (ret) { + /* Do not autolearn if we have this symbol already */ + if (rspamd_stat_has_classifier_symbols(task, mres, cl)) { + ret = FALSE; + task->flags &= ~(RSPAMD_TASK_FLAG_LEARN_HAM | + RSPAMD_TASK_FLAG_LEARN_SPAM | + RSPAMD_TASK_FLAG_LEARN_CLASS); + /* Clear the autolearn class from mempool */ + rspamd_mempool_set_variable(task->task_pool, "autolearn_class", NULL, NULL); } + else if (mres != NULL) { + const char *autolearn_class = rspamd_task_get_autolearn_class(task); + + if (autolearn_class) { + if (strcmp(autolearn_class, "ham") == 0) { + msg_info_task("<%s>: autolearn ham for classifier " + "'%s' as message's " + "score is negative: %.2f", + MESSAGE_FIELD(task, message_id), cl->cfg->name, + mres->score); + } + else if (strcmp(autolearn_class, "spam") == 0) { + msg_info_task("<%s>: autolearn spam for classifier " + "'%s' as message's " + "action is reject, score: %.2f", + MESSAGE_FIELD(task, message_id), cl->cfg->name, + mres->score); + } + else { + msg_info_task("<%s>: autolearn class '%s' for classifier " + "'%s', score: %.2f", + MESSAGE_FIELD(task, message_id), autolearn_class, + cl->cfg->name, mres->score); + } + } - task->classifier = cl->cfg->name; - break; + task->classifier = cl->cfg->name; + break; + } } } }