local N = "bayes"
local function gen_classify_functor(redis_params, classify_script_id)
- return function(task, expanded_key, id, is_spam, stat_tokens, callback)
-
+ return function(task, expanded_key, id, class_labels, stat_tokens, callback)
local function classify_redis_cb(err, data)
lua_util.debugm(N, task, 'classify redis cb: %s, %s', err, data)
if err then
callback(task, false, err)
else
- callback(task, true, data[1], data[2], data[3], data[4])
+ -- Handle both binary and multi-class results
+ if type(data[1]) == "table" then
+ -- Multi-class format: [learned_counts_table, outputs_table]
+ -- Convert to binary format for backward compatibility if needed
+ local learned_counts = data[1]
+ local outputs = data[2]
+
+ -- For now, return ham/spam data if available for backward compatibility
+ local learned_ham = learned_counts["H"] or learned_counts["ham"] or 0
+ local learned_spam = learned_counts["S"] or learned_counts["spam"] or 0
+ local output_ham = outputs["H"] or outputs["ham"] or {}
+ local output_spam = outputs["S"] or outputs["spam"] or {}
+
+ callback(task, true, learned_ham, learned_spam, output_ham, output_spam)
+ else
+ -- Binary format: [learned_ham, learned_spam, output_ham, output_spam]
+ callback(task, true, data[1], data[2], data[3], data[4])
+ end
+ end
+ end
+
+ -- Determine class labels to send to Redis script
+ local script_class_labels
+ if type(class_labels) == "table" then
+ script_class_labels = class_labels
+ else
+ -- Single class label or boolean compatibility
+ if class_labels == true or class_labels == "true" then
+ script_class_labels = "S" -- spam
+ elseif class_labels == false or class_labels == "false" then
+ script_class_labels = "H" -- ham
+ else
+ script_class_labels = class_labels -- string class label
end
end
lua_redis.exec_redis_script(classify_script_id,
- { task = task, is_write = false, key = expanded_key },
- classify_redis_cb, { expanded_key, stat_tokens })
+ { task = task, is_write = false, key = expanded_key },
+ classify_redis_cb, { expanded_key, script_class_labels, stat_tokens })
end
end
local function gen_learn_functor(redis_params, learn_script_id)
- return function(task, expanded_key, id, is_spam, symbol, is_unlearn, stat_tokens, callback, maybe_text_tokens)
+ return function(task, expanded_key, id, class_label, symbol, is_unlearn, stat_tokens, callback, maybe_text_tokens)
local function learn_redis_cb(err, data)
lua_util.debugm(N, task, 'learn redis cb: %s, %s', err, data)
if err then
end
end
+ -- Convert class_label for backward compatibility
+ local script_class_label = class_label
+ if class_label == true or class_label == "true" then
+ script_class_label = "S" -- spam
+ elseif class_label == false or class_label == "false" then
+ script_class_label = "H" -- ham
+ end
+
if maybe_text_tokens then
lua_redis.exec_redis_script(learn_script_id,
- { task = task, is_write = true, key = expanded_key },
- learn_redis_cb,
- { expanded_key, tostring(is_spam), symbol, tostring(is_unlearn), stat_tokens, maybe_text_tokens })
+ { task = task, is_write = true, key = expanded_key },
+ learn_redis_cb,
+ { expanded_key, script_class_label, symbol, tostring(is_unlearn), stat_tokens, maybe_text_tokens })
else
lua_redis.exec_redis_script(learn_script_id,
- { task = task, is_write = true, key = expanded_key },
- learn_redis_cb, { expanded_key, tostring(is_spam), symbol, tostring(is_unlearn), stat_tokens })
+ { task = task, is_write = true, key = expanded_key },
+ learn_redis_cb, { expanded_key, script_class_label, symbol, tostring(is_unlearn), stat_tokens })
end
-
end
end
--- @param classifier_ucl ucl of the classifier config
--- @param statfile_ucl ucl of the statfile config
--- @return a pair of (classify_functor, learn_functor) or `nil` in case of error
-exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol, is_spam, ev_base, stat_periodic_cb)
-
+exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol, class_label, ev_base, stat_periodic_cb)
local redis_params = load_redis_params(classifier_ucl, statfile_ucl)
if not redis_params then
if ev_base then
rspamd_config:add_periodic(ev_base, 0.0, function(cfg, _)
-
local function stat_redis_cb(err, data)
lua_util.debugm(N, cfg, 'stat redis cb: %s, %s', err, data)
end
end
+ -- Convert class_label to learn key
+ local learn_key
+ if class_label == true or class_label == "true" or class_label == "S" then
+ learn_key = "learns_spam"
+ elseif class_label == false or class_label == "false" or class_label == "H" then
+ learn_key = "learns_ham"
+ else
+ -- For other class labels, use learns_<class_label>
+ learn_key = "learns_" .. string.lower(tostring(class_label))
+ end
+
lua_redis.exec_redis_script(stat_script_id,
- { ev_base = ev_base, cfg = cfg, is_write = false },
- stat_redis_cb, { tostring(cursor),
- symbol,
- is_spam and "learns_spam" or "learns_ham",
- tostring(max_users) })
+ { ev_base = ev_base, cfg = cfg, is_write = false },
+ stat_redis_cb, { tostring(cursor),
+ symbol,
+ learn_key,
+ tostring(max_users) })
return statfile_ucl.monitor_timeout or classifier_ucl.monitor_timeout or 30.0
end)
end
local function gen_cache_check_functor(redis_params, check_script_id, conf)
local packed_conf = ucl.to_format(conf, 'msgpack')
return function(task, cache_id, callback)
-
local function classify_redis_cb(err, data)
lua_util.debugm(N, task, 'check cache redis cb: %s, %s (%s)', err, data, type(data))
if err then
lua_util.debugm(N, task, 'checking cache: %s', cache_id)
lua_redis.exec_redis_script(check_script_id,
- { task = task, is_write = false, key = cache_id },
- classify_redis_cb, { cache_id, packed_conf })
+ { task = task, is_write = false, key = cache_id },
+ classify_redis_cb, { cache_id, packed_conf })
end
end
local function gen_cache_learn_functor(redis_params, learn_script_id, conf)
local packed_conf = ucl.to_format(conf, 'msgpack')
- return function(task, cache_id, is_spam)
+ return function(task, cache_id, class_name)
local function learn_redis_cb(err, data)
lua_util.debugm(N, task, 'learn_cache redis cb: %s, %s', err, data)
end
- lua_util.debugm(N, task, 'try to learn cache: %s', cache_id)
- lua_redis.exec_redis_script(learn_script_id,
- { task = task, is_write = true, key = cache_id },
- learn_redis_cb,
- { cache_id, is_spam and "1" or "0", packed_conf })
+ -- Handle backward compatibility for boolean values
+ local cache_class_name = class_name
+ if type(class_name) == "boolean" then
+ cache_class_name = class_name and "spam" or "ham"
+ elseif class_name == true or class_name == "true" then
+ cache_class_name = "spam"
+ elseif class_name == false or class_name == "false" then
+ cache_class_name = "ham"
+ end
+ lua_util.debugm(N, task, 'try to learn cache: %s as %s', cache_id, cache_class_name)
+ lua_redis.exec_redis_script(learn_script_id,
+ { task = task, is_write = true, key = cache_id },
+ learn_redis_cb,
+ { cache_id, cache_class_name, packed_conf })
end
end
local default_conf = {
cache_prefix = "learned_ids",
cache_max_elt = 10000, -- Maximum number of elements in the cache key
- cache_max_keys = 5, -- Maximum number of keys in the cache
- cache_elt_len = 32, -- Length of the element in the cache (will trim id to that value)
+ cache_max_keys = 5, -- Maximum number of keys in the cache
+ cache_elt_len = 32, -- Length of the element in the cache (will trim id to that value)
}
local conf = lua_util.override_defaults(default_conf, classifier_ucl)
local learn_script_id = lua_redis.load_redis_script_from_file("bayes_cache_learn.lua", redis_params)
return gen_cache_check_functor(redis_params, check_script_id, conf), gen_cache_learn_functor(redis_params,
- learn_script_id, conf)
+ learn_script_id, conf)
end
return exports
--- Lua script to perform cache checking for bayes classification
+-- Lua script to perform cache checking for bayes classification (multi-class)
-- This script accepts the following parameters:
-- key1 - cache id
--- key2 - is spam (1 or 0)
+-- key2 - class name (e.g. "spam", "ham", "transactional")
-- key3 - configuration table in message pack
local cache_id = KEYS[1]
-local is_spam = KEYS[2]
+local class_name = KEYS[2]
local conf = cmsgpack.unpack(KEYS[3])
+
+-- Handle backward compatibility for binary values
+if class_name == "1" then
+ class_name = "spam"
+elseif class_name == "0" then
+ class_name = "ham"
+end
cache_id = string.sub(cache_id, 1, conf.cache_elt_len)
-- Try each prefix that is in Redis (as some other instance might have set it)
local have = redis.call('HGET', prefix, cache_id)
if have then
- -- Already in cache, but is_spam changes when relearning
- redis.call('HSET', prefix, cache_id, is_spam)
+ -- Already in cache, but class_name changes when relearning
+ redis.call('HSET', prefix, cache_id, class_name)
return false
end
end
if count < lim then
-- We can add it to this prefix
- redis.call('HSET', prefix, cache_id, is_spam)
+ redis.call('HSET', prefix, cache_id, class_name)
added = true
end
end
if exists then
if not expired then
redis.call('DEL', prefix)
- redis.call('HSET', prefix, cache_id, is_spam)
+ redis.call('HSET', prefix, cache_id, class_name)
-- Do not expire anything else
expired = true
--- Lua script to perform bayes classification
+-- Lua script to perform bayes classification (multi-class)
-- This script accepts the following parameters:
-- key1 - prefix for bayes tokens (e.g. for per-user classification)
--- key2 - set of tokens encoded in messagepack array of strings
+-- key2 - class labels: either table of all class labels (multi-class) or single string (binary)
+-- key3 - set of tokens encoded in messagepack array of strings
local prefix = KEYS[1]
-local output_spam = {}
-local output_ham = {}
+local class_labels_arg = KEYS[2]
+local input_tokens = cmsgpack.unpack(KEYS[3])
-local learned_ham = tonumber(redis.call('HGET', prefix, 'learns_ham')) or 0
-local learned_spam = tonumber(redis.call('HGET', prefix, 'learns_spam')) or 0
+-- Determine if this is multi-class (table) or binary (string)
+local class_labels = {}
+if type(class_labels_arg) == "table" then
+ class_labels = class_labels_arg
+else
+ -- Binary compatibility: handle old boolean or single string format
+ if class_labels_arg == "true" then
+ class_labels = { "S" } -- spam
+ elseif class_labels_arg == "false" then
+ class_labels = { "H" } -- ham
+ else
+ class_labels = { class_labels_arg } -- single class label
+ end
+end
--- Output is a set of pairs (token_index, token_count), tokens that are not
--- found are not filled.
--- This optimisation will save a lot of space for sparse tokens, and in Bayes that assumption is normally held
+-- Get learned counts for all classes
+local learned_counts = {}
+for _, label in ipairs(class_labels) do
+ local key = 'learns_' .. string.lower(label)
+ -- Also try legacy keys for backward compatibility
+ if label == 'H' then
+ key = 'learns_ham'
+ elseif label == 'S' then
+ key = 'learns_spam'
+ end
+ learned_counts[label] = tonumber(redis.call('HGET', prefix, key)) or 0
+end
-if learned_ham > 0 and learned_spam > 0 then
- local input_tokens = cmsgpack.unpack(KEYS[2])
- for i, token in ipairs(input_tokens) do
- local token_data = redis.call('HMGET', token, 'H', 'S')
+-- Get token data for all classes (only if we have learns for any class)
+local outputs = {}
+local has_learns = false
+for _, count in pairs(learned_counts) do
+ if count > 0 then
+ has_learns = true
+ break
+ end
+end
- if token_data then
- local ham_count = token_data[1]
- local spam_count = token_data[2]
+if has_learns then
+ -- Initialize outputs for each class
+ for _, label in ipairs(class_labels) do
+ outputs[label] = {}
+ end
- if ham_count then
- table.insert(output_ham, { i, tonumber(ham_count) })
- end
+ -- Process each token
+ for i, token in ipairs(input_tokens) do
+ local token_data = redis.call('HMGET', token, unpack(class_labels))
- if spam_count then
- table.insert(output_spam, { i, tonumber(spam_count) })
+ if token_data then
+ for j, label in ipairs(class_labels) do
+ local count = token_data[j]
+ if count then
+ table.insert(outputs[label], { i, tonumber(count) })
+ end
end
end
end
end
-return { learned_ham, learned_spam, output_ham, output_spam }
\ No newline at end of file
+-- Format output for backward compatibility
+if #class_labels == 2 and class_labels[1] == 'H' and class_labels[2] == 'S' then
+ -- Binary format: [learned_ham, learned_spam, output_ham, output_spam]
+ return {
+ learned_counts['H'] or 0,
+ learned_counts['S'] or 0,
+ outputs['H'] or {},
+ outputs['S'] or {}
+ }
+elseif #class_labels == 2 and class_labels[1] == 'S' and class_labels[2] == 'H' then
+ -- Binary format: [learned_ham, learned_spam, output_ham, output_spam]
+ return {
+ learned_counts['H'] or 0,
+ learned_counts['S'] or 0,
+ outputs['H'] or {},
+ outputs['S'] or {}
+ }
+else
+ -- Multi-class format: [learned_counts_table, outputs_table]
+ return { learned_counts, outputs }
+end
--- Lua script to perform bayes learning
+-- Lua script to perform bayes learning (multi-class)
-- This script accepts the following parameters:
-- key1 - prefix for bayes tokens (e.g. for per-user classification)
--- key2 - boolean is_spam
+-- key2 - class label string (e.g. "S", "H", "T")
-- key3 - string symbol
-- key4 - boolean is_unlearn
-- key5 - set of tokens encoded in messagepack array of strings
-- key6 - set of text tokens (if any) encoded in messagepack array of strings (size must be twice of `KEYS[5]`)
local prefix = KEYS[1]
-local is_spam = KEYS[2] == 'true' and true or false
+local class_label = KEYS[2]
local symbol = KEYS[3]
local is_unlearn = KEYS[4] == 'true' and true or false
local input_tokens = cmsgpack.unpack(KEYS[5])
text_tokens = cmsgpack.unpack(KEYS[6])
end
-local hash_key = is_spam and 'S' or 'H'
-local learned_key = is_spam and 'learns_spam' or 'learns_ham'
+-- Handle backward compatibility for boolean values
+if class_label == 'true' then
+ class_label = 'S' -- spam
+elseif class_label == 'false' then
+ class_label = 'H' -- ham
+end
+
+local hash_key = class_label
+local learned_key = 'learns_' .. string.lower(class_label)
+
+-- Handle legacy keys for backward compatibility
+if class_label == 'S' then
+ learned_key = 'learns_spam'
+elseif class_label == 'H' then
+ learned_key = 'learns_ham'
+end
redis.call('SADD', symbol .. '_keys', prefix)
-redis.call('HSET', prefix, 'version', '2') -- new schema
+redis.call('HSET', prefix, 'version', '2') -- new schema
redis.call('HINCRBY', prefix, learned_key, is_unlearn and -1 or 1) -- increase or decrease learned count
for i, token in ipairs(input_tokens) do
char *symbol; /**< symbol of statfile */
char *label; /**< label of this statfile */
ucl_object_t *opts; /**< other options */
- gboolean is_spam; /**< spam flag */
+ char *class_name; /**< class name for multi-class classification */
+ gboolean is_spam; /**< DEPRECATED: spam flag - use class_name instead */
struct rspamd_classifier_config *clcf; /**< parent pointer of classifier configuration */
gpointer data; /**< opaque data */
};
double min_prob_strength; /**< use only tokens with probability in [0.5 - MPS, 0.5 + MPS] */
unsigned int min_learns; /**< minimum number of learns for each statfile */
unsigned int flags;
+ GHashTable *class_labels; /**< class_name -> backend_symbol mapping for multi-class */
+ GPtrArray *class_names; /**< ordered list of class names */
};
struct rspamd_worker_bind_conf {
*/
gboolean rspamd_config_check_statfiles(struct rspamd_classifier_config *cf);
-/*
- * Find classifier config by name
+/**
+ * Multi-class configuration helpers
+ */
+gboolean rspamd_config_parse_class_labels(ucl_object_t *obj,
+ GHashTable **class_labels);
+
+gboolean rspamd_config_migrate_binary_config(struct rspamd_statfile_config *stcf);
+
+gboolean rspamd_config_validate_class_config(struct rspamd_classifier_config *ccf,
+ GError **err);
+
+const char *rspamd_config_get_class_label(struct rspamd_classifier_config *ccf,
+ const char *class_name);
+
+/**
+ * Find classifier by name
*/
struct rspamd_classifier_config *rspamd_config_find_classifier(
- struct rspamd_config *cfg,
- const char *name);
+ struct rspamd_config *cfg, const char *name);
void rspamd_ucl_add_conf_macros(struct ucl_parser *parser,
struct rspamd_config *cfg);
st->opts = (ucl_object_t *) obj;
st->clcf = ccf;
- const auto *val = ucl_object_lookup(obj, "spam");
- if (val == nullptr) {
+ /* Handle migration from old 'spam' field to new 'class' field */
+ const auto *class_val = ucl_object_lookup(obj, "class");
+ const auto *spam_val = ucl_object_lookup(obj, "spam");
+
+ if (class_val != nullptr && spam_val != nullptr) {
+ msg_warn_config("statfile %s has both 'class' and 'spam' fields, using 'class' field",
+ st->symbol);
+ }
+
+ if (class_val == nullptr && spam_val == nullptr) {
+ /* Neither field present, try to guess by symbol name */
msg_info_config(
- "statfile %s has no explicit 'spam' setting, trying to guess by symbol",
+ "statfile %s has no explicit 'class' or 'spam' setting, trying to guess by symbol",
st->symbol);
if (rspamd_substring_search_caseless(st->symbol,
strlen(st->symbol), "spam", 4) != -1) {
st->is_spam = TRUE;
+ st->class_name = rspamd_mempool_strdup(pool, "spam");
}
else if (rspamd_substring_search_caseless(st->symbol,
strlen(st->symbol), "ham", 3) != -1) {
st->is_spam = FALSE;
+ st->class_name = rspamd_mempool_strdup(pool, "ham");
}
else {
g_set_error(err,
CFG_RCL_ERROR,
EINVAL,
- "cannot guess spam setting from %s",
+ "cannot guess class setting from %s, please specify 'class' field",
st->symbol);
return FALSE;
}
- msg_info_config("guessed that statfile with symbol %s is %s",
- st->symbol,
- st->is_spam ? "spam" : "ham");
+ msg_info_config("guessed that statfile with symbol %s has class '%s'",
+ st->symbol, st->class_name);
}
+ else if (class_val == nullptr && spam_val != nullptr) {
+ /* Only spam field present - migrate to class */
+ msg_warn_config("statfile %s uses deprecated 'spam' field, please use 'class' instead",
+ st->symbol);
+ if (st->is_spam) {
+ st->class_name = rspamd_mempool_strdup(pool, "spam");
+ }
+ else {
+ st->class_name = rspamd_mempool_strdup(pool, "ham");
+ }
+ }
+ /* If class field is present, it was already parsed by the default parser */
return TRUE;
}
return FALSE;
}
+static gboolean
+rspamd_rcl_class_labels_handler(rspamd_mempool_t *pool,
+ const ucl_object_t *obj,
+ const char *key,
+ gpointer ud,
+ struct rspamd_rcl_section *section,
+ GError **err)
+{
+ auto *ccf = static_cast<rspamd_classifier_config *>(ud);
+
+ if (obj->type != UCL_OBJECT) {
+ g_set_error(err, CFG_RCL_ERROR, EINVAL,
+ "class_labels must be an object");
+ return FALSE;
+ }
+
+ if (!rspamd_config_parse_class_labels((ucl_object_t *) obj, &ccf->class_labels)) {
+ g_set_error(err, CFG_RCL_ERROR, EINVAL,
+ "invalid class_labels configuration");
+ return FALSE;
+ }
+
+ return TRUE;
+}
+
static gboolean
rspamd_rcl_classifier_handler(rspamd_mempool_t *pool,
const ucl_object_t *obj,
}
ccf->opts = (ucl_object_t *) obj;
+
+ /* Validate multi-class configuration */
+ GError *validation_err = nullptr;
+ if (!rspamd_config_validate_class_config(ccf, &validation_err)) {
+ if (validation_err) {
+ g_propagate_error(err, validation_err);
+ }
+ else {
+ g_set_error(err, CFG_RCL_ERROR, EINVAL,
+ "multi-class configuration validation failed for classifier '%s'",
+ ccf->name ? ccf->name : "unknown");
+ }
+ return FALSE;
+ }
+
cfg->classifiers = g_list_prepend(cfg->classifiers, ccf);
return TRUE;
0,
"Name of classifier");
+ /*
+ * Multi-class configuration
+ */
+ rspamd_rcl_add_section_doc(&top, sub,
+ "class_labels", nullptr,
+ rspamd_rcl_class_labels_handler,
+ UCL_OBJECT,
+ FALSE,
+ TRUE,
+ sub->doc_ref,
+ "Class to backend label mapping for multi-class classification");
+
/*
* Statfile defaults
*/
G_STRUCT_OFFSET(struct rspamd_statfile_config, label),
0,
"Statfile unique label");
+ rspamd_rcl_add_default_handler(ssub,
+ "class",
+ rspamd_rcl_parse_struct_string,
+ G_STRUCT_OFFSET(struct rspamd_statfile_config, class_name),
+ 0,
+ "Class name for multi-class classification");
rspamd_rcl_add_default_handler(ssub,
"spam",
rspamd_rcl_parse_struct_boolean,
G_STRUCT_OFFSET(struct rspamd_statfile_config, is_spam),
0,
- "Sets if this statfile contains spam samples");
+ "DEPRECATED: Sets if this statfile contains spam samples (use 'class' instead)");
}
if (!(skip_sections && g_hash_table_lookup(skip_sections, "composite"))) {
return FALSE;
}
+
+gboolean
+rspamd_config_parse_class_labels(ucl_object_t *obj, GHashTable **class_labels)
+{
+ const ucl_object_t *cur;
+ ucl_object_iter_t it = nullptr;
+ const char *class_name, *label;
+
+ if (!obj || ucl_object_type(obj) != UCL_OBJECT) {
+ return FALSE;
+ }
+
+ *class_labels = g_hash_table_new_full(g_str_hash, g_str_equal, g_free, g_free);
+
+ while ((cur = ucl_object_iterate(obj, &it, true)) != nullptr) {
+ class_name = ucl_object_key(cur);
+ label = ucl_object_tostring(cur);
+
+ if (class_name && label) {
+ /* Validate class name: alphanumeric + underscore, max 32 chars */
+ if (strlen(class_name) > 32) {
+ msg_err("class name '%s' is too long (max 32 characters)", class_name);
+ g_hash_table_destroy(*class_labels);
+ *class_labels = nullptr;
+ return FALSE;
+ }
+
+ for (const char *p = class_name; *p; p++) {
+ if (!g_ascii_isalnum(*p) && *p != '_') {
+ msg_err("class name '%s' contains invalid character '%c'", class_name, *p);
+ g_hash_table_destroy(*class_labels);
+ *class_labels = nullptr;
+ return FALSE;
+ }
+ }
+
+ /* Validate label uniqueness */
+ GHashTableIter label_iter;
+ gpointer key, value;
+ g_hash_table_iter_init(&label_iter, *class_labels);
+ while (g_hash_table_iter_next(&label_iter, &key, &value)) {
+ if (strcmp((const char *) value, label) == 0) {
+ msg_err("backend label '%s' is used by multiple classes", label);
+ g_hash_table_destroy(*class_labels);
+ *class_labels = nullptr;
+ return FALSE;
+ }
+ }
+
+ g_hash_table_insert(*class_labels, g_strdup(class_name), g_strdup(label));
+ }
+ }
+
+ return g_hash_table_size(*class_labels) > 0;
+}
+
+gboolean
+rspamd_config_migrate_binary_config(struct rspamd_statfile_config *stcf)
+{
+ if (stcf->class_name != nullptr) {
+ /* Already migrated or using new format */
+ return TRUE;
+ }
+
+ if (stcf->is_spam) {
+ stcf->class_name = g_strdup("spam");
+ msg_info("migrated statfile '%s' from is_spam=true to class='spam'",
+ stcf->symbol ? stcf->symbol : "unknown");
+ }
+ else {
+ stcf->class_name = g_strdup("ham");
+ msg_info("migrated statfile '%s' from is_spam=false to class='ham'",
+ stcf->symbol ? stcf->symbol : "unknown");
+ }
+
+ return TRUE;
+}
+
+gboolean
+rspamd_config_validate_class_config(struct rspamd_classifier_config *ccf, GError **err)
+{
+ GList *cur;
+ GHashTable *seen_classes = nullptr;
+ struct rspamd_statfile_config *stcf;
+ unsigned int class_count = 0;
+
+ if (!ccf || !ccf->statfiles) {
+ g_set_error(err, g_quark_from_static_string("config"), 1,
+ "classifier has no statfiles defined");
+ return FALSE;
+ }
+
+ seen_classes = g_hash_table_new_full(g_str_hash, g_str_equal, g_free, nullptr);
+
+ /* Iterate through statfiles and collect classes */
+ cur = ccf->statfiles;
+ while (cur) {
+ stcf = (struct rspamd_statfile_config *) cur->data;
+
+ /* Migrate binary config if needed */
+ if (!rspamd_config_migrate_binary_config(stcf)) {
+ g_set_error(err, g_quark_from_static_string("config"), 1,
+ "failed to migrate binary config for statfile '%s'",
+ stcf->symbol ? stcf->symbol : "unknown");
+ g_hash_table_destroy(seen_classes);
+ return FALSE;
+ }
+
+ /* Check class name */
+ if (!stcf->class_name || strlen(stcf->class_name) == 0) {
+ g_set_error(err, g_quark_from_static_string("config"), 1,
+ "statfile '%s' has no class defined",
+ stcf->symbol ? stcf->symbol : "unknown");
+ g_hash_table_destroy(seen_classes);
+ return FALSE;
+ }
+
+ /* Track unique classes */
+ if (!g_hash_table_contains(seen_classes, stcf->class_name)) {
+ g_hash_table_insert(seen_classes, g_strdup(stcf->class_name), GINT_TO_POINTER(1));
+ class_count++;
+ }
+
+ cur = g_list_next(cur);
+ }
+
+ /* Validate class count */
+ if (class_count < 2) {
+ g_set_error(err, g_quark_from_static_string("config"), 1,
+ "classifier must have at least 2 classes, found %u", class_count);
+ g_hash_table_destroy(seen_classes);
+ return FALSE;
+ }
+
+ if (class_count > 20) {
+ msg_warn("classifier has %u classes, performance may be degraded above 20 classes",
+ class_count);
+ }
+
+ /* Initialize classifier class tracking */
+ if (ccf->class_names) {
+ g_ptr_array_unref(ccf->class_names);
+ }
+ ccf->class_names = g_ptr_array_new_with_free_func(g_free);
+
+ /* Populate class names array */
+ GHashTableIter iter;
+ gpointer key, value;
+ g_hash_table_iter_init(&iter, seen_classes);
+ while (g_hash_table_iter_next(&iter, &key, &value)) {
+ g_ptr_array_add(ccf->class_names, g_strdup((const char *) key));
+ }
+
+ g_hash_table_destroy(seen_classes);
+ return TRUE;
+}
+
+const char *
+rspamd_config_get_class_label(struct rspamd_classifier_config *ccf, const char *class_name)
+{
+ if (!ccf || !ccf->class_labels || !class_name) {
+ return nullptr;
+ }
+
+ return (const char *) g_hash_table_lookup(ccf->class_labels, class_name);
+}
if (all_done && (task->flags & RSPAMD_TASK_FLAG_LEARN_AUTO) &&
!RSPAMD_TASK_IS_EMPTY(task) &&
- !(task->flags & (RSPAMD_TASK_FLAG_LEARN_SPAM | RSPAMD_TASK_FLAG_LEARN_HAM))) {
+ !(task->flags & (RSPAMD_TASK_FLAG_LEARN_SPAM | RSPAMD_TASK_FLAG_LEARN_HAM | RSPAMD_TASK_FLAG_LEARN_CLASS))) {
rspamd_stat_check_autolearn(task);
}
break;
case RSPAMD_TASK_STAGE_LEARN:
case RSPAMD_TASK_STAGE_LEARN_PRE:
case RSPAMD_TASK_STAGE_LEARN_POST:
- if (task->flags & (RSPAMD_TASK_FLAG_LEARN_SPAM | RSPAMD_TASK_FLAG_LEARN_HAM)) {
+ if (task->flags & (RSPAMD_TASK_FLAG_LEARN_SPAM | RSPAMD_TASK_FLAG_LEARN_HAM | RSPAMD_TASK_FLAG_LEARN_CLASS)) {
if (task->err == NULL) {
- if (!rspamd_stat_learn(task,
- task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM,
- task->cfg->lua_state, task->classifier,
- st, &stat_error)) {
+ gboolean learn_result = FALSE;
+
+ if (task->flags & RSPAMD_TASK_FLAG_LEARN_CLASS) {
+ /* Multi-class learning */
+ const char *autolearn_class = rspamd_task_get_autolearn_class(task);
+ if (autolearn_class) {
+ learn_result = rspamd_stat_learn_class(task, autolearn_class,
+ task->cfg->lua_state, task->classifier,
+ st, &stat_error);
+ }
+ else {
+ g_set_error(&stat_error, g_quark_from_static_string("stat"), 500,
+ "No autolearn class specified for multi-class learning");
+ }
+ }
+ else {
+ /* Legacy binary learning */
+ learn_result = rspamd_stat_learn(task,
+ task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM,
+ task->cfg->lua_state, task->classifier,
+ st, &stat_error);
+ }
+
+ if (!learn_result) {
if (stat_error == NULL) {
g_set_error(&stat_error,
#define RSPAMD_TASK_FLAG_LEARN_SPAM (1u << 12u)
#define RSPAMD_TASK_FLAG_LEARN_HAM (1u << 13u)
#define RSPAMD_TASK_FLAG_LEARN_AUTO (1u << 14u)
+#define RSPAMD_TASK_FLAG_LEARN_CLASS (1u << 25u)
#define RSPAMD_TASK_FLAG_BROKEN_HEADERS (1u << 15u)
-#define RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS (1u << 16u)
-#define RSPAMD_TASK_FLAG_HAS_HAM_TOKENS (1u << 17u)
+/* Removed RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS and RSPAMD_TASK_FLAG_HAS_HAM_TOKENS - not needed in multi-class */
#define RSPAMD_TASK_FLAG_EMPTY (1u << 18u)
#define RSPAMD_TASK_FLAG_PROFILE (1u << 19u)
#define RSPAMD_TASK_FLAG_GREYLISTED (1u << 20u)
#define RSPAMD_TASK_FLAG_SSL (1u << 22u)
#define RSPAMD_TASK_FLAG_BAD_UNICODE (1u << 23u)
#define RSPAMD_TASK_FLAG_MESSAGE_REWRITE (1u << 24u)
-#define RSPAMD_TASK_FLAG_MAX_SHIFT (24u)
+#define RSPAMD_TASK_FLAG_MAX_SHIFT (25u)
/* Request has been done by a local client */
#define RSPAMD_TASK_PROTOCOL_FLAG_LOCAL_CLIENT (1u << 1u)
gpointer runtime)
{
auto *cdbp = CDB_FROM_RAW(runtime);
- bool seen_values = false;
for (auto i = 0u; i < tokens->len; i++) {
rspamd_token_t *tok;
if (res) {
tok->values[id] = res.value();
- seen_values = true;
}
else {
tok->values[id] = 0;
}
}
- if (seen_values) {
- if (cdbp->is_spam()) {
- task->flags |= RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS;
- }
- else {
- task->flags |= RSPAMD_TASK_FLAG_HAS_HAM_TOKENS;
- }
- }
+ /* No longer need to set flags - multi-class handles missing data naturally */
return true;
}
{
auto *cdbp = CDB_FROM_RAW(ctx);
delete cdbp;
-}
\ No newline at end of file
+}
#define RSPAMD_STATFILE_VERSION \
{ \
- '1', '2' \
- }
+ '1', '2'}
#define BACKUP_SUFFIX ".old"
static void rspamd_mmaped_file_set_block_common(rspamd_mempool_t *pool,
tok->values[id] = rspamd_mmaped_file_get_block(mf, h1, h2);
}
- if (mf->cf->is_spam) {
- task->flags |= RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS;
- }
- else {
- task->flags |= RSPAMD_TASK_FLAG_HAS_HAM_TOKENS;
- }
+ /* No longer need to set flags - multi-class handles missing data naturally */
return TRUE;
}
}
static auto maybe_recover_from_mempool(struct rspamd_task *task, const char *redis_object_expanded,
- bool is_spam) -> std::optional<redis_stat_runtime<T> *>
+ const char *class_label) -> std::optional<redis_stat_runtime<T> *>
{
- auto var_name = fmt::format("{}_{}", redis_object_expanded, is_spam ? "S" : "H");
+ auto var_name = fmt::format("{}_{}", redis_object_expanded, class_label);
auto *res = rspamd_mempool_get_variable(task->task_pool, var_name.c_str());
if (res) {
return true;
}
- auto save_in_mempool(bool is_spam) const
+ auto save_in_mempool(const char *class_label) const
{
- auto var_name = fmt::format("{}_{}", redis_object_expanded, is_spam ? "S" : "H");
+ auto var_name = fmt::format("{}_{}", redis_object_expanded, class_label);
/* We do not set destructor for the variable, as it should be already added on creation */
rspamd_mempool_set_variable(task->task_pool, var_name.c_str(), (gpointer) this, nullptr);
msg_debug_bayes("saved runtime in mempool at %s", var_name.c_str());
return g_quark_from_static_string(M);
}
+/*
+ * Get the class label for a statfile (for multi-class support)
+ */
+static const char *
+get_class_label(struct rspamd_statfile_config *stcf)
+{
+ /* Try to get the label from the classifier config first */
+ if (stcf->clcf && stcf->clcf->class_labels && stcf->class_name) {
+ const char *label = rspamd_config_get_class_label(stcf->clcf, stcf->class_name);
+ if (label) {
+ return label;
+ }
+ /* If no label mapping found, use class name directly */
+ return stcf->class_name;
+ }
+
+ /* Fallback to legacy binary classification */
+ return stcf->is_spam ? "S" : "H";
+}
+
/*
* Non-static for lua unit testing
*/
ucl_object_push_lua(L, st->classifier->cfg->opts, false);
ucl_object_push_lua(L, st->stcf->opts, false);
lua_pushstring(L, backend->stcf->symbol);
- lua_pushboolean(L, backend->stcf->is_spam);
+ lua_pushstring(L, get_class_label(backend->stcf)); /* Pass class label instead of boolean */
/* Push event loop if there is one available (e.g. we are not in rspamadm mode) */
if (ctx->event_loop) {
return nullptr;
}
+ const char *class_label = get_class_label(stcf);
+
/* Look for the cached results */
if (!learn) {
auto maybe_existing = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
- object_expanded, stcf->is_spam);
+ object_expanded, class_label);
if (maybe_existing) {
auto *rt = maybe_existing.value();
/* No cached result (or learn), create new one */
auto *rt = new redis_stat_runtime<float>(ctx, task, object_expanded);
- if (!learn) {
+ if (!learn && stcf->clcf && stcf->clcf->class_names && stcf->clcf->class_names->len > 2) {
/*
- * For check, we also need to create the opposite class runtime to avoid
- * double call for Redis scripts.
- * This runtime will be filled later.
+ * For multi-class classification, we need to create runtimes for ALL classes
+ * to avoid multiple Redis calls. The actual Redis call will fetch data for all classes.
*/
+ GList *cur = stcf->clcf->statfiles;
+ while (cur) {
+ auto *other_stcf = (struct rspamd_statfile_config *) cur->data;
+ if (other_stcf != stcf) {
+ const char *other_label = get_class_label(other_stcf);
+
+ auto maybe_other_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
+ object_expanded, other_label);
+ if (!maybe_other_rt) {
+ auto *other_rt = new redis_stat_runtime<float>(ctx, task, object_expanded);
+ other_rt->save_in_mempool(other_label);
+ other_rt->need_redis_call = false;
+ }
+ }
+ cur = g_list_next(cur);
+ }
+ }
+ else if (!learn) {
+ /*
+ * For binary classification, create the opposite class runtime to avoid
+ * double call for Redis scripts (backward compatibility).
+ */
+ const char *opposite_label = stcf->is_spam ? "H" : "S";
auto maybe_opposite_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
- object_expanded,
- !stcf->is_spam);
+ object_expanded, opposite_label);
if (!maybe_opposite_rt) {
auto *opposite_rt = new redis_stat_runtime<float>(ctx, task, object_expanded);
- opposite_rt->save_in_mempool(!stcf->is_spam);
+ opposite_rt->save_in_mempool(opposite_label);
opposite_rt->need_redis_call = false;
}
}
- rt->save_in_mempool(stcf->is_spam);
+ rt->save_in_mempool(class_label);
return rt;
}
bool result = lua_toboolean(L, 2);
if (result) {
- /* Indexes:
- * 3 - learned_ham (int)
- * 4 - learned_spam (int)
- * 5 - ham_tokens (pair<int, int>)
- * 6 - spam_tokens (pair<int, int>)
+ /* Check if this is binary format [learned_ham, learned_spam, ham_tokens, spam_tokens]
+ * or multi-class format [learned_counts_table, outputs_table]
*/
- /*
- * We need to fill our runtime AND the opposite runtime
- */
auto filler_func = [](redis_stat_runtime<float> *rt, lua_State *L, unsigned learned, int tokens_pos) {
rt->learned = learned;
redis_stat_runtime<float>::result_type *res;
rt->set_results(res);
};
- auto opposite_rt_maybe = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
- rt->redis_object_expanded,
- !rt->stcf->is_spam);
+ /* Check if result[3] is a number (binary) or table (multi-class) */
+ lua_rawgeti(L, 3, 1); /* Get first element of result array */
+ bool is_binary_format = lua_isnumber(L, -1);
+ lua_pop(L, 1);
- if (!opposite_rt_maybe) {
- msg_err_task("internal error: cannot find opposite runtime for cookie %s", cookie);
+ if (is_binary_format) {
+ /* Binary format: [learned_ham, learned_spam, ham_tokens, spam_tokens] */
- return 0;
- }
+ /* Find the opposite runtime for binary classification compatibility */
+ const char *opposite_label;
+ if (rt->stcf->class_name) {
+ /* Multi-class: find a different class (simplified for now) */
+ opposite_label = strcmp(get_class_label(rt->stcf), "S") == 0 ? "H" : "S";
+ }
+ else {
+ /* Binary: use opposite spam/ham */
+ opposite_label = rt->stcf->is_spam ? "H" : "S";
+ }
+ auto opposite_rt_maybe = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
+ rt->redis_object_expanded,
+ opposite_label);
+
+ if (!opposite_rt_maybe) {
+ msg_err_task("internal error: cannot find opposite runtime for cookie %s", cookie);
+ return 0;
+ }
+
+ if (rt->stcf->is_spam) {
+ filler_func(rt, L, lua_tointeger(L, 4), 6);
+ filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 3), 5);
+ }
+ else {
+ filler_func(rt, L, lua_tointeger(L, 3), 5);
+ filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 4), 6);
+ }
- if (rt->stcf->is_spam) {
- filler_func(rt, L, lua_tointeger(L, 4), 6);
- filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 3), 5);
+ /* Process all tokens */
+ g_assert(rt->tokens != nullptr);
+ rt->process_tokens(rt->tokens);
+ opposite_rt_maybe.value()->process_tokens(rt->tokens);
}
else {
- filler_func(rt, L, lua_tointeger(L, 3), 5);
- filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 4), 6);
- }
+ /* Multi-class format: [learned_counts_table, outputs_table] */
+
+ /* Get learned counts table (index 3) and outputs table (index 4) */
+ lua_rawgeti(L, 3, 1); /* learned_counts */
+ lua_rawgeti(L, 3, 2); /* outputs */
+
+ /* Iterate through all class labels to fill all runtimes */
+ if (rt->stcf->clcf && rt->stcf->clcf->class_labels) {
+ GHashTableIter iter;
+ gpointer key, value;
+ g_hash_table_iter_init(&iter, rt->stcf->clcf->class_labels);
+
+ while (g_hash_table_iter_next(&iter, &key, &value)) {
+ const char *class_label = (const char *) value;
+
+ /* Find runtime for this class */
+ auto class_rt_maybe = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
+ rt->redis_object_expanded,
+ class_label);
+
+ if (class_rt_maybe) {
+ auto *class_rt = class_rt_maybe.value();
+
+ /* Get learned count for this class */
+ lua_pushstring(L, class_label);
+ lua_gettable(L, -3); /* learned_counts[class_label] */
+ unsigned learned = lua_tointeger(L, -1);
+ lua_pop(L, 1);
+
+ /* Get outputs for this class */
+ lua_pushstring(L, class_label);
+ lua_gettable(L, -2); /* outputs[class_label] */
+ int outputs_pos = lua_gettop(L);
+
+ filler_func(class_rt, L, learned, outputs_pos);
+ lua_pop(L, 1);
+ }
+ }
+ }
+
+ lua_pop(L, 2); /* Pop learned_counts and outputs tables */
- /* Mark task as being processed */
- task->flags |= RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS | RSPAMD_TASK_FLAG_HAS_HAM_TOKENS;
+ /* Process tokens for all runtimes */
+ g_assert(rt->tokens != nullptr);
+ rt->process_tokens(rt->tokens);
+ }
- /* Process all tokens */
- g_assert(rt->tokens != nullptr);
- rt->process_tokens(rt->tokens);
- opposite_rt_maybe.value()->process_tokens(rt->tokens);
+ /* Tokens processed - no need to set flags in multi-class approach */
}
else {
/* Error message is on index 3 */
rspamd_lua_task_push(L, task);
lua_pushstring(L, rt->redis_object_expanded);
lua_pushinteger(L, id);
- lua_pushboolean(L, rt->stcf->is_spam);
+
+ /* Send all class labels for multi-class support */
+ if (rt->stcf->clcf && rt->stcf->clcf->class_labels && g_hash_table_size(rt->stcf->clcf->class_labels) > 0) {
+ /* Multi-class: send array of all class labels */
+ lua_newtable(L);
+ GHashTableIter iter;
+ gpointer key, value;
+ int idx = 1;
+ g_hash_table_iter_init(&iter, rt->stcf->clcf->class_labels);
+ while (g_hash_table_iter_next(&iter, &key, &value)) {
+ lua_pushstring(L, (const char *) value); /* Use the label, not class name */
+ lua_rawseti(L, -2, idx++);
+ }
+ }
+ else {
+ /* Binary compatibility: send current class label as single string */
+ lua_pushstring(L, get_class_label(rt->stcf));
+ }
+
lua_new_text(L, tokens_buf, tokens_len, false);
/* Store rt in random cookie */
rspamd_lua_task_push(L, task);
lua_pushstring(L, rt->redis_object_expanded);
lua_pushinteger(L, id);
- lua_pushboolean(L, rt->stcf->is_spam);
+ lua_pushstring(L, get_class_label(rt->stcf)); /* Pass class label instead of boolean */
lua_pushstring(L, rt->stcf->symbol);
/* Detect unlearn */
}
}
- if (rt->cf->is_spam) {
- task->flags |= RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS;
- }
- else {
- task->flags |= RSPAMD_TASK_FLAG_HAS_HAM_TOKENS;
- }
+ /* No longer need to set flags - multi-class handles missing data naturally */
}
}
struct bayes_task_closure {
- double ham_prob;
- double spam_prob;
+ double ham_prob; /* Kept for binary compatibility */
+ double spam_prob; /* Kept for binary compatibility */
double meta_skip_prob;
uint64_t processed_tokens;
uint64_t total_hits;
struct rspamd_task *task;
};
+/* Multi-class classification closure */
+struct bayes_multiclass_closure {
+ double *class_log_probs; /* Array of log probabilities for each class */
+ uint64_t *class_learns; /* Learning counts for each class */
+ char **class_names; /* Array of class names */
+ unsigned int num_classes; /* Number of classes */
+ double meta_skip_prob;
+ uint64_t processed_tokens;
+ uint64_t total_hits;
+ uint64_t text_tokens;
+ struct rspamd_task *task;
+ struct rspamd_classifier_config *cfg;
+};
+
/*
* Mathematically we use pow(complexity, complexity), where complexity is the
* window index
}
}
+/*
+ * Multinomial token classification for multi-class Bayes
+ */
+static void
+bayes_classify_token_multiclass(struct rspamd_classifier *ctx,
+ rspamd_token_t *tok,
+ struct bayes_multiclass_closure *cl)
+{
+ unsigned int i, j;
+ int id;
+ struct rspamd_statfile *st;
+ struct rspamd_task *task;
+ const char *token_type = "txt";
+ double val, fw, w;
+ unsigned int *class_counts;
+ unsigned int total_count = 0;
+
+ task = cl->task;
+
+ /* Skip meta tokens probabilistically if configured */
+ if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_META && cl->meta_skip_prob > 0) {
+ val = rspamd_random_double_fast();
+ if (val <= cl->meta_skip_prob) {
+ return;
+ }
+ token_type = "meta";
+ }
+
+ /* Allocate array for class counts */
+ class_counts = g_alloca(cl->num_classes * sizeof(unsigned int));
+ memset(class_counts, 0, cl->num_classes * sizeof(unsigned int));
+
+ /* Collect counts for each class */
+ for (i = 0; i < ctx->statfiles_ids->len; i++) {
+ id = g_array_index(ctx->statfiles_ids, int, i);
+ st = g_ptr_array_index(ctx->ctx->statfiles, id);
+ g_assert(st != NULL);
+ val = tok->values[id];
+
+ if (val > 0) {
+ /* Find which class this statfile belongs to */
+ for (j = 0; j < cl->num_classes; j++) {
+ if (st->stcf->class_name &&
+ strcmp(st->stcf->class_name, cl->class_names[j]) == 0) {
+ class_counts[j] += val;
+ total_count += val;
+ cl->total_hits += val;
+ break;
+ }
+ }
+ }
+ }
+
+ /* Calculate multinomial probability for this token */
+ if (total_count >= ctx->cfg->min_token_hits) {
+ /* Feature weight calculation */
+ if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_UNIGRAM) {
+ fw = 1.0;
+ }
+ else {
+ fw = feature_weight[tok->window_idx % G_N_ELEMENTS(feature_weight)];
+ }
+
+ w = (fw * total_count) / (1.0 + fw * total_count);
+
+ /* Apply multinomial model for each class */
+ for (j = 0; j < cl->num_classes; j++) {
+ double class_freq = (double) class_counts[j] / MAX(1.0, (double) cl->class_learns[j]);
+ double class_prob = PROB_COMBINE(class_freq, total_count, w, 1.0 / cl->num_classes);
+
+ /* Skip probabilities too close to uniform (1/num_classes) */
+ double uniform_prior = 1.0 / cl->num_classes;
+ if (fabs(class_prob - uniform_prior) < ctx->cfg->min_prob_strength) {
+ continue;
+ }
+
+ cl->class_log_probs[j] += log(class_prob);
+ }
+
+ cl->processed_tokens++;
+ if (!(tok->flags & RSPAMD_STAT_TOKEN_FLAG_META)) {
+ cl->text_tokens++;
+ }
+
+ if (tok->t1 && tok->t2) {
+ msg_debug_bayes("token(%s) %uL <%*s:%*s>: weight: %.3f, total_count: %ud, "
+ "processed for %u classes",
+ token_type, tok->data,
+ (int) tok->t1->stemmed.len, tok->t1->stemmed.begin,
+ (int) tok->t2->stemmed.len, tok->t2->stemmed.begin,
+ fw, total_count, cl->num_classes);
+ }
+ }
+}
+
+/*
+ * Multinomial Bayes classification with Fisher confidence
+ */
+static gboolean
+bayes_classify_multiclass(struct rspamd_classifier *ctx,
+ GPtrArray *tokens,
+ struct rspamd_task *task)
+{
+ struct bayes_multiclass_closure cl;
+ rspamd_token_t *tok;
+ unsigned int i, j, text_tokens = 0;
+ int id;
+ struct rspamd_statfile *st;
+ rspamd_multiclass_result_t *result;
+ double *normalized_probs;
+ double max_log_prob = -INFINITY;
+ unsigned int winning_class_idx = 0;
+ double confidence;
+
+ g_assert(ctx != NULL);
+ g_assert(tokens != NULL);
+
+ /* Initialize multi-class closure */
+ memset(&cl, 0, sizeof(cl));
+ cl.task = task;
+ cl.cfg = ctx->cfg;
+
+ /* Get class information from classifier config */
+ if (!ctx->cfg->class_names || ctx->cfg->class_names->len < 2) {
+ msg_debug_bayes("insufficient classes for multiclass classification");
+ return TRUE; /* Fall back to binary mode */
+ }
+
+ cl.num_classes = ctx->cfg->class_names->len;
+ cl.class_names = (char **) ctx->cfg->class_names->pdata;
+ cl.class_log_probs = g_alloca(cl.num_classes * sizeof(double));
+ cl.class_learns = g_alloca(cl.num_classes * sizeof(uint64_t));
+
+ /* Initialize probabilities and get learning counts */
+ for (i = 0; i < cl.num_classes; i++) {
+ cl.class_log_probs[i] = 0.0;
+ cl.class_learns[i] = 0;
+ }
+
+ /* Collect learning counts for each class */
+ for (i = 0; i < ctx->statfiles_ids->len; i++) {
+ id = g_array_index(ctx->statfiles_ids, int, i);
+ st = g_ptr_array_index(ctx->ctx->statfiles, id);
+ g_assert(st != NULL);
+
+ for (j = 0; j < cl.num_classes; j++) {
+ if (st->stcf->class_name &&
+ strcmp(st->stcf->class_name, cl.class_names[j]) == 0) {
+ cl.class_learns[j] += st->backend->total_learns(task,
+ g_ptr_array_index(task->stat_runtimes, id), ctx->ctx);
+ break;
+ }
+ }
+ }
+
+ /* Check minimum learns requirement */
+ if (ctx->cfg->min_learns > 0) {
+ for (i = 0; i < cl.num_classes; i++) {
+ if (cl.class_learns[i] < ctx->cfg->min_learns) {
+ msg_info_task("not classified as %s. The class needs more "
+ "training samples. Currently: %ul; minimum %ud required",
+ cl.class_names[i], cl.class_learns[i], ctx->cfg->min_learns);
+ return TRUE;
+ }
+ }
+ }
+
+ /* Count text tokens */
+ for (i = 0; i < tokens->len; i++) {
+ tok = g_ptr_array_index(tokens, i);
+ if (!(tok->flags & RSPAMD_STAT_TOKEN_FLAG_META)) {
+ text_tokens++;
+ }
+ }
+
+ if (text_tokens == 0) {
+ msg_info_task("skipped classification as there are no text tokens. "
+ "Total tokens: %ud",
+ tokens->len);
+ return TRUE;
+ }
+
+ /* Set meta token skip probability */
+ if (text_tokens > tokens->len - text_tokens) {
+ cl.meta_skip_prob = 0.0;
+ }
+ else {
+ cl.meta_skip_prob = 1.0 - (double) text_tokens / tokens->len;
+ }
+
+ /* Process all tokens */
+ for (i = 0; i < tokens->len; i++) {
+ tok = g_ptr_array_index(tokens, i);
+ bayes_classify_token_multiclass(ctx, tok, &cl);
+ }
+
+ if (cl.processed_tokens == 0) {
+ msg_info_bayes("no tokens found in bayes database "
+ "(%ud total tokens, %ud text tokens), ignore stats",
+ tokens->len, text_tokens);
+ return TRUE;
+ }
+
+ if (ctx->cfg->min_tokens > 0 &&
+ cl.text_tokens < (int) (ctx->cfg->min_tokens * 0.1)) {
+ msg_info_bayes("ignore bayes probability since we have "
+ "found too few text tokens: %uL (of %ud checked), "
+ "at least %d required",
+ cl.text_tokens, text_tokens,
+ (int) (ctx->cfg->min_tokens * 0.1));
+ return TRUE;
+ }
+
+ /* Normalize probabilities using softmax */
+ normalized_probs = g_alloca(cl.num_classes * sizeof(double));
+
+ /* Find maximum for numerical stability */
+ for (i = 0; i < cl.num_classes; i++) {
+ if (cl.class_log_probs[i] > max_log_prob) {
+ max_log_prob = cl.class_log_probs[i];
+ winning_class_idx = i;
+ }
+ }
+
+ /* Apply softmax normalization */
+ double sum_exp = 0.0;
+ for (i = 0; i < cl.num_classes; i++) {
+ normalized_probs[i] = exp(cl.class_log_probs[i] - max_log_prob);
+ sum_exp += normalized_probs[i];
+ }
+
+ if (sum_exp > 0) {
+ for (i = 0; i < cl.num_classes; i++) {
+ normalized_probs[i] /= sum_exp;
+ }
+ }
+ else {
+ /* Fallback to uniform distribution */
+ for (i = 0; i < cl.num_classes; i++) {
+ normalized_probs[i] = 1.0 / cl.num_classes;
+ }
+ }
+
+ /* Calculate confidence using Fisher method for the winning class */
+ if (max_log_prob > -300) {
+ confidence = 1.0 - inv_chi_square(task, max_log_prob, cl.processed_tokens);
+ }
+ else {
+ confidence = normalized_probs[winning_class_idx];
+ }
+
+ /* Create and store multiclass result */
+ result = g_new0(rspamd_multiclass_result_t, 1);
+ result->class_names = g_new(char *, cl.num_classes);
+ result->probabilities = g_new(double, cl.num_classes);
+ result->num_classes = cl.num_classes;
+ result->winning_class = cl.class_names[winning_class_idx]; /* Reference, not copy */
+ result->confidence = confidence;
+
+ for (i = 0; i < cl.num_classes; i++) {
+ result->class_names[i] = g_strdup(cl.class_names[i]);
+ result->probabilities[i] = normalized_probs[i];
+ }
+
+ rspamd_task_set_multiclass_result(task, result);
+
+ /* Insert symbol for winning class if confidence is significant */
+ if (confidence > 0.05) {
+ char sumbuf[32];
+ double final_prob = rspamd_normalize_probability(confidence, 0.5);
+
+ rspamd_snprintf(sumbuf, sizeof(sumbuf), "%.2f%%", confidence * 100.0);
+
+ /* Find the statfile for the winning class to get the symbol */
+ for (i = 0; i < ctx->statfiles_ids->len; i++) {
+ id = g_array_index(ctx->statfiles_ids, int, i);
+ st = g_ptr_array_index(ctx->ctx->statfiles, id);
+
+ if (st->stcf->class_name &&
+ strcmp(st->stcf->class_name, cl.class_names[winning_class_idx]) == 0) {
+ rspamd_task_insert_result(task, st->stcf->symbol, final_prob, sumbuf);
+ break;
+ }
+ }
+
+ msg_debug_bayes("multiclass classification: winning class '%s' with "
+ "probability %.3f, confidence %.3f, %uL tokens processed",
+ cl.class_names[winning_class_idx],
+ normalized_probs[winning_class_idx],
+ confidence, cl.processed_tokens);
+ }
+
+ return TRUE;
+}
+
gboolean
bayes_init(struct rspamd_config *cfg,
g_assert(ctx != NULL);
g_assert(tokens != NULL);
+ /* Check if this is a multi-class classifier */
+ if (ctx->cfg->class_names && ctx->cfg->class_names->len >= 2) {
+ /* Verify that at least one statfile has class_name set (indicating new multi-class config) */
+ gboolean has_class_names = FALSE;
+ for (i = 0; i < ctx->statfiles_ids->len; i++) {
+ int id = g_array_index(ctx->statfiles_ids, int, i);
+ struct rspamd_statfile *st = g_ptr_array_index(ctx->ctx->statfiles, id);
+ if (st->stcf->class_name) {
+ has_class_names = TRUE;
+ break;
+ }
+ }
+
+ if (has_class_names) {
+ msg_debug_bayes("using multiclass classification with %u classes",
+ (unsigned int) ctx->cfg->class_names->len);
+ return bayes_classify_multiclass(ctx, tokens, task);
+ }
+ }
+
+ /* Fall back to binary classification */
+ msg_debug_bayes("using binary classification");
memset(&cl, 0, sizeof(cl));
cl.task = task;
return TRUE;
}
+
+gboolean
+bayes_learn_class(struct rspamd_classifier *ctx,
+ GPtrArray *tokens,
+ struct rspamd_task *task,
+ const char *class_name,
+ gboolean unlearn,
+ GError **err)
+{
+ unsigned int i, j, total_cnt;
+ int id;
+ struct rspamd_statfile *st;
+ rspamd_token_t *tok;
+ gboolean incrementing;
+ unsigned int *class_counts = NULL;
+ struct rspamd_statfile **class_statfiles = NULL;
+ unsigned int num_classes = 0;
+
+ g_assert(ctx != NULL);
+ g_assert(tokens != NULL);
+ g_assert(class_name != NULL);
+
+ incrementing = ctx->cfg->flags & RSPAMD_FLAG_CLASSIFIER_INCREMENTING_BACKEND;
+
+ /* Count classes and prepare arrays for multi-class learning */
+ if (ctx->cfg->class_names && ctx->cfg->class_names->len > 0) {
+ num_classes = ctx->cfg->class_names->len;
+ class_counts = g_alloca(num_classes * sizeof(unsigned int));
+ class_statfiles = g_alloca(num_classes * sizeof(struct rspamd_statfile *));
+ memset(class_counts, 0, num_classes * sizeof(unsigned int));
+ memset(class_statfiles, 0, num_classes * sizeof(struct rspamd_statfile *));
+ }
+
+ for (i = 0; i < tokens->len; i++) {
+ total_cnt = 0;
+ tok = g_ptr_array_index(tokens, i);
+
+ /* Reset class counts for this token */
+ if (num_classes > 0) {
+ memset(class_counts, 0, num_classes * sizeof(unsigned int));
+ }
+
+ for (j = 0; j < ctx->statfiles_ids->len; j++) {
+ id = g_array_index(ctx->statfiles_ids, int, j);
+ st = g_ptr_array_index(ctx->ctx->statfiles, id);
+ g_assert(st != NULL);
+
+ /* Determine if this statfile matches our target class */
+ gboolean is_target_class = FALSE;
+ if (st->stcf->class_name) {
+ /* Multi-class: exact class name match */
+ is_target_class = (strcmp(st->stcf->class_name, class_name) == 0);
+ }
+ else {
+ /* Legacy binary: map class_name to spam/ham */
+ if (strcmp(class_name, "spam") == 0 || strcmp(class_name, "S") == 0) {
+ is_target_class = st->stcf->is_spam;
+ }
+ else if (strcmp(class_name, "ham") == 0 || strcmp(class_name, "H") == 0) {
+ is_target_class = !st->stcf->is_spam;
+ }
+ }
+
+ if (is_target_class) {
+ /* Learning: increment the target class */
+ if (incrementing) {
+ tok->values[id] = 1;
+ }
+ else {
+ tok->values[id]++;
+ }
+ total_cnt += tok->values[id];
+
+ /* Track class counts for debugging */
+ if (num_classes > 0) {
+ for (unsigned int k = 0; k < num_classes; k++) {
+ const char *check_class = (const char *) g_ptr_array_index(ctx->cfg->class_names, k);
+ if (st->stcf->class_name && strcmp(st->stcf->class_name, check_class) == 0) {
+ class_counts[k] += tok->values[id];
+ class_statfiles[k] = st;
+ break;
+ }
+ }
+ }
+ }
+ else {
+ /* Unlearning: decrement other classes if unlearn flag is set */
+ if (tok->values[id] > 0 && unlearn) {
+ if (incrementing) {
+ tok->values[id] = -1;
+ }
+ else {
+ tok->values[id]--;
+ }
+ total_cnt += tok->values[id];
+
+ /* Track class counts for debugging */
+ if (num_classes > 0) {
+ for (unsigned int k = 0; k < num_classes; k++) {
+ const char *check_class = (const char *) g_ptr_array_index(ctx->cfg->class_names, k);
+ if (st->stcf->class_name && strcmp(st->stcf->class_name, check_class) == 0) {
+ class_counts[k] += tok->values[id];
+ class_statfiles[k] = st;
+ break;
+ }
+ }
+ }
+ }
+ else if (incrementing) {
+ tok->values[id] = 0;
+ }
+ }
+ }
+
+ /* Debug logging */
+ if (tok->t1 && tok->t2) {
+ if (num_classes > 0) {
+ GString *debug_str = g_string_new("");
+ for (unsigned int k = 0; k < num_classes; k++) {
+ const char *check_class = (const char *) g_ptr_array_index(ctx->cfg->class_names, k);
+ g_string_append_printf(debug_str, "%s:%d ", check_class, class_counts[k]);
+ }
+ msg_debug_bayes("token %uL <%*s:%*s>: window: %d, total_count: %d, "
+ "class_counts: %s",
+ tok->data,
+ (int) tok->t1->stemmed.len, tok->t1->stemmed.begin,
+ (int) tok->t2->stemmed.len, tok->t2->stemmed.begin,
+ tok->window_idx, total_cnt, debug_str->str);
+ g_string_free(debug_str, TRUE);
+ }
+ else {
+ msg_debug_bayes("token %uL <%*s:%*s>: window: %d, total_count: %d, "
+ "class: %s",
+ tok->data,
+ (int) tok->t1->stemmed.len, tok->t1->stemmed.begin,
+ (int) tok->t2->stemmed.len, tok->t2->stemmed.begin,
+ tok->window_idx, total_cnt, class_name);
+ }
+ }
+ else {
+ msg_debug_bayes("token %uL <?:?>: window: %d, total_count: %d, "
+ "class: %s",
+ tok->data,
+ tok->window_idx, total_cnt, class_name);
+ }
+ }
+
+ return TRUE;
+}
gboolean unlearn,
GError **err);
+ gboolean (*learn_class_func)(struct rspamd_classifier *ctx,
+ GPtrArray *input,
+ struct rspamd_task *task,
+ const char *class_name,
+ gboolean unlearn,
+ GError **err);
+
void (*fin_func)(struct rspamd_classifier *cl);
};
gboolean unlearn,
GError **err);
+gboolean bayes_learn_class(struct rspamd_classifier *ctx,
+ GPtrArray *tokens,
+ struct rspamd_task *task,
+ const char *class_name,
+ gboolean unlearn,
+ GError **err);
+
void bayes_fin(struct rspamd_classifier *);
/* Generic lua classifier */
unsigned int stage,
GError **err);
+/**
+ * Learn task as a specific class, task must be processed prior to this call
+ * @param task task to learn
+ * @param class_name name of the class to learn (e.g., "spam", "ham", "transactional")
+ * @param L lua state
+ * @param classifier NULL to learn all classifiers, name to learn a specific one
+ * @param stage learning stage
+ * @param err error returned
+ * @return TRUE if task has been learned
+ */
+rspamd_stat_result_t rspamd_stat_learn_class(struct rspamd_task *task,
+ const char *class_name,
+ lua_State *L,
+ const char *classifier,
+ unsigned int stage,
+ GError **err);
+
/**
* Get the overall statistics for all statfile backends
* @param cfg configuration
void rspamd_stat_unload(void);
+/**
+ * Multi-class classification result structure
+ */
+typedef struct {
+ char **class_names; /**< Array of class names */
+ double *probabilities; /**< Array of probabilities for each class */
+ unsigned int num_classes; /**< Number of classes */
+ const char *winning_class; /**< Name of the winning class (reference, not owned) */
+ double confidence; /**< Confidence of the winning class */
+} rspamd_multiclass_result_t;
+
+/**
+ * Set multi-class classification result for a task
+ */
+void rspamd_task_set_multiclass_result(struct rspamd_task *task,
+ rspamd_multiclass_result_t *result);
+
+/**
+ * Get multi-class classification result from a task
+ */
+rspamd_multiclass_result_t *rspamd_task_get_multiclass_result(struct rspamd_task *task);
+
+/**
+ * Free multi-class result structure
+ */
+void rspamd_multiclass_result_free(rspamd_multiclass_result_t *result);
+
+/**
+ * Set autolearn class for a task
+ */
+void rspamd_task_set_autolearn_class(struct rspamd_task *task, const char *class_name);
+
+/**
+ * Get autolearn class from a task
+ */
+const char *rspamd_task_get_autolearn_class(struct rspamd_task *task);
+
#ifdef __cplusplus
}
#endif
.init_func = lua_classifier_init,
.classify_func = lua_classifier_classify,
.learn_spam_func = lua_classifier_learn_spam,
+ .learn_class_func = NULL, /* TODO: implement lua multi-class learning */
.fin_func = NULL,
};
.init_func = bayes_init,
.classify_func = bayes_classify,
.learn_spam_func = bayes_learn_spam,
+ .learn_class_func = bayes_learn_class,
.fin_func = bayes_fin,
}};
.dec_learns = rspamd_##eltn##_dec_learns, \
.get_stat = rspamd_##eltn##_get_stat, \
.load_tokenizer_config = rspamd_##eltn##_load_tokenizer_config, \
- .close = rspamd_##eltn##_close \
- }
+ .close = rspamd_##eltn##_close}
#define RSPAMD_STAT_BACKEND_ELT_READONLY(nam, eltn) \
{ \
.name = #nam, \
.dec_learns = NULL, \
.get_stat = rspamd_##eltn##_get_stat, \
.load_tokenizer_config = rspamd_##eltn##_load_tokenizer_config, \
- .close = rspamd_##eltn##_close \
- }
+ .close = rspamd_##eltn##_close}
static struct rspamd_stat_backend stat_backends[] = {
RSPAMD_STAT_BACKEND_ELT(mmap, mmaped_file),
.runtime = rspamd_stat_cache_##eltn##_runtime, \
.check = rspamd_stat_cache_##eltn##_check, \
.learn = rspamd_stat_cache_##eltn##_learn, \
- .close = rspamd_stat_cache_##eltn##_close \
- }
+ .close = rspamd_stat_cache_##eltn##_close}
static struct rspamd_stat_cache stat_caches[] = {
RSPAMD_STAT_CACHE_ELT(sqlite3, sqlite3),
static const double similarity_threshold = 80.0;
+void rspamd_task_set_multiclass_result(struct rspamd_task *task, rspamd_multiclass_result_t *result)
+{
+ g_assert(task != NULL);
+ g_assert(result != NULL);
+
+ rspamd_mempool_set_variable(task->task_pool, "multiclass_bayes_result", result,
+ (rspamd_mempool_destruct_t) rspamd_multiclass_result_free);
+}
+
+rspamd_multiclass_result_t *
+rspamd_task_get_multiclass_result(struct rspamd_task *task)
+{
+ g_assert(task != NULL);
+
+ return (rspamd_multiclass_result_t *) rspamd_mempool_get_variable(task->task_pool,
+ "multiclass_bayes_result");
+}
+
+void rspamd_multiclass_result_free(rspamd_multiclass_result_t *result)
+{
+ if (result == NULL) {
+ return;
+ }
+
+ g_free(result->class_names);
+ g_free(result->probabilities);
+ /* winning_class is a reference, not owned - don't free */
+ g_free(result);
+}
+
+void rspamd_task_set_autolearn_class(struct rspamd_task *task, const char *class_name)
+{
+ g_assert(task != NULL);
+ g_assert(class_name != NULL);
+
+ /* Store the class name in the mempool */
+ const char *class_name_copy = rspamd_mempool_strdup(task->task_pool, class_name);
+ rspamd_mempool_set_variable(task->task_pool, "autolearn_class",
+ (gpointer) class_name_copy, NULL);
+
+ /* Set the appropriate flags */
+ task->flags |= RSPAMD_TASK_FLAG_LEARN_CLASS;
+
+ /* For backward compatibility, also set binary flags */
+ if (strcmp(class_name, "spam") == 0 || strcmp(class_name, "S") == 0) {
+ task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
+ }
+ else if (strcmp(class_name, "ham") == 0 || strcmp(class_name, "H") == 0) {
+ task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
+ }
+}
+
+const char *
+rspamd_task_get_autolearn_class(struct rspamd_task *task)
+{
+ g_assert(task != NULL);
+
+ if (task->flags & RSPAMD_TASK_FLAG_LEARN_CLASS) {
+ return (const char *) rspamd_mempool_get_variable(task->task_pool, "autolearn_class");
+ }
+
+ /* Fallback to binary flags for backward compatibility */
+ if (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) {
+ return "spam";
+ }
+ else if (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM) {
+ return "ham";
+ }
+
+ return NULL;
+}
+
static void
rspamd_stat_tokenize_parts_metadata(struct rspamd_stat_ctx *st_ctx,
struct rspamd_task *task)
}
/*
- * Do not classify a message if some class is missing
+ * Multi-class approach: don't check for missing classes
+ * Missing tokens naturally result in 0 probability
*/
- if (!(task->flags & RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS)) {
- msg_info_task("skip statistics as SPAM class is missing");
-
- return;
- }
- if (!(task->flags & RSPAMD_TASK_FLAG_HAS_HAM_TOKENS)) {
- msg_info_task("skip statistics as HAM class is missing");
-
- return;
- }
for (i = 0; i < st_ctx->classifiers->len; i++) {
cl = g_ptr_array_index(st_ctx->classifiers, i);
if (sel->cache && sel->cachecf) {
rt = cl->cache->runtime(task, sel->cachecf, FALSE);
- learn_res = cl->cache->check(task, spam, rt);
+
+ /* For multi-class learning, determine spam boolean from class name if available */
+ gboolean cache_spam = spam; /* Default to original spam parameter */
+ const char *autolearn_class = rspamd_task_get_autolearn_class(task);
+ if (autolearn_class) {
+ if (strcmp(autolearn_class, "spam") == 0 || strcmp(autolearn_class, "S") == 0) {
+ cache_spam = TRUE;
+ }
+ else if (strcmp(autolearn_class, "ham") == 0 || strcmp(autolearn_class, "H") == 0) {
+ cache_spam = FALSE;
+ }
+ else {
+ /* For other classes, use a heuristic or default to spam for cache purposes */
+ cache_spam = TRUE; /* Non-ham classes are treated as spam for cache */
+ }
+ }
+
+ learn_res = cl->cache->check(task, cache_spam, rt);
}
if (learn_res == RSPAMD_LEARN_IGNORE) {
continue;
}
- if (cl->subrs->learn_spam_func(cl, task->tokens, task, spam,
- task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) {
- learned = TRUE;
+ /* Check if classifier supports multi-class learning and if we should use it */
+ if (cl->subrs->learn_class_func && cl->cfg->class_names && cl->cfg->class_names->len > 2) {
+ /* Multi-class learning: determine class name from task flags or autolearn result */
+ const char *class_name = NULL;
+
+ if (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) {
+ /* Find spam class name */
+ for (unsigned int k = 0; k < cl->cfg->class_names->len; k++) {
+ const char *check_class = (const char *) g_ptr_array_index(cl->cfg->class_names, k);
+ /* Look for statfile with this class that is spam */
+ GList *cur = cl->cfg->statfiles;
+ while (cur) {
+ struct rspamd_statfile_config *stcf = (struct rspamd_statfile_config *) cur->data;
+ if (stcf->class_name && strcmp(stcf->class_name, check_class) == 0 && stcf->is_spam) {
+ class_name = check_class;
+ break;
+ }
+ cur = g_list_next(cur);
+ }
+ if (class_name) break;
+ }
+ if (!class_name) class_name = "spam"; /* fallback */
+ }
+ else if (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM) {
+ /* Find ham class name */
+ for (unsigned int k = 0; k < cl->cfg->class_names->len; k++) {
+ const char *check_class = (const char *) g_ptr_array_index(cl->cfg->class_names, k);
+ /* Look for statfile with this class that is ham */
+ GList *cur = cl->cfg->statfiles;
+ while (cur) {
+ struct rspamd_statfile_config *stcf = (struct rspamd_statfile_config *) cur->data;
+ if (stcf->class_name && strcmp(stcf->class_name, check_class) == 0 && !stcf->is_spam) {
+ class_name = check_class;
+ break;
+ }
+ cur = g_list_next(cur);
+ }
+ if (class_name) break;
+ }
+ if (!class_name) class_name = "ham"; /* fallback */
+ }
+ else {
+ /* Fallback to spam/ham based on the spam parameter */
+ class_name = spam ? "spam" : "ham";
+ }
+
+ if (cl->subrs->learn_class_func(cl, task->tokens, task, class_name,
+ task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) {
+ learned = TRUE;
+ }
+ }
+ else {
+ /* Binary learning: use existing function */
+ if (cl->subrs->learn_spam_func(cl, task->tokens, task, spam,
+ task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) {
+ learned = TRUE;
+ }
}
}
if (cl->cache) {
cache_run = cl->cache->runtime(task, cl->cachecf, TRUE);
- cl->cache->learn(task, spam, cache_run);
+
+ /* For multi-class learning, determine spam boolean from class name if available */
+ gboolean cache_spam = spam; /* Default to original spam parameter */
+ const char *autolearn_class = rspamd_task_get_autolearn_class(task);
+ if (autolearn_class) {
+ if (strcmp(autolearn_class, "spam") == 0 || strcmp(autolearn_class, "S") == 0) {
+ cache_spam = TRUE;
+ }
+ else if (strcmp(autolearn_class, "ham") == 0 || strcmp(autolearn_class, "H") == 0) {
+ cache_spam = FALSE;
+ }
+ else {
+ /* For other classes, use a heuristic or default to spam for cache purposes */
+ cache_spam = TRUE; /* Non-ham classes are treated as spam for cache */
+ }
+ }
+
+ cl->cache->learn(task, cache_spam, cache_run);
}
}
return res;
}
+static gboolean
+rspamd_stat_classifiers_learn_class(struct rspamd_stat_ctx *st_ctx,
+ struct rspamd_task *task,
+ const char *classifier,
+ const char *class_name,
+ GError **err)
+{
+ struct rspamd_classifier *cl, *sel = NULL;
+ unsigned int i;
+ gboolean learned = FALSE, too_small = FALSE, too_large = FALSE;
+
+ if ((task->flags & RSPAMD_TASK_FLAG_ALREADY_LEARNED) && err != NULL &&
+ *err == NULL) {
+ /* Do not learn twice */
+ g_set_error(err, rspamd_stat_quark(), 208, "<%s> has been already "
+ "learned as %s, ignore it",
+ MESSAGE_FIELD(task, message_id),
+ class_name);
+
+ return FALSE;
+ }
+
+ /* Check whether we have learned that file */
+ for (i = 0; i < st_ctx->classifiers->len; i++) {
+ cl = g_ptr_array_index(st_ctx->classifiers, i);
+
+ /* Skip other classifiers if they are not needed */
+ if (classifier != NULL && (cl->cfg->name == NULL ||
+ g_ascii_strcasecmp(classifier, cl->cfg->name) != 0)) {
+ continue;
+ }
+
+ sel = cl;
+
+ /* Now check max and min tokens */
+ if (cl->cfg->min_tokens > 0 && task->tokens->len < cl->cfg->min_tokens) {
+ msg_info_task(
+ "<%s> contains less tokens than required for %s classifier: "
+ "%ud < %ud",
+ MESSAGE_FIELD(task, message_id),
+ cl->cfg->name,
+ task->tokens->len,
+ cl->cfg->min_tokens);
+ too_small = TRUE;
+ continue;
+ }
+ else if (cl->cfg->max_tokens > 0 && task->tokens->len > cl->cfg->max_tokens) {
+ msg_info_task(
+ "<%s> contains more tokens than allowed for %s classifier: "
+ "%ud > %ud",
+ MESSAGE_FIELD(task, message_id),
+ cl->cfg->name,
+ task->tokens->len,
+ cl->cfg->max_tokens);
+ too_large = TRUE;
+ continue;
+ }
+
+ /* Use the new multi-class learning function if available */
+ if (cl->subrs->learn_class_func) {
+ if (cl->subrs->learn_class_func(cl, task->tokens, task, class_name,
+ task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) {
+ learned = TRUE;
+ }
+ }
+ else {
+ /* Fallback to binary learning with class name mapping */
+ gboolean is_spam;
+ if (strcmp(class_name, "spam") == 0 || strcmp(class_name, "S") == 0) {
+ is_spam = TRUE;
+ }
+ else if (strcmp(class_name, "ham") == 0 || strcmp(class_name, "H") == 0) {
+ is_spam = FALSE;
+ }
+ else {
+ /* For unknown classes with binary classifier, skip */
+ msg_info_task("skipping class '%s' for binary classifier %s",
+ class_name, cl->cfg->name);
+ continue;
+ }
+
+ if (cl->subrs->learn_spam_func(cl, task->tokens, task, is_spam,
+ task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) {
+ learned = TRUE;
+ }
+ }
+ }
+
+ if (sel == NULL) {
+ if (classifier) {
+ g_set_error(err, rspamd_stat_quark(), 404, "cannot find classifier "
+ "with name %s",
+ classifier);
+ }
+ else {
+ g_set_error(err, rspamd_stat_quark(), 404, "no classifiers defined");
+ }
+
+ return FALSE;
+ }
+
+ if (!learned && err && *err == NULL) {
+ if (too_large) {
+ g_set_error(err, rspamd_stat_quark(), 204,
+ "<%s> contains more tokens than allowed for %s classifier: "
+ "%d > %d",
+ MESSAGE_FIELD(task, message_id),
+ sel->cfg->name,
+ task->tokens->len,
+ sel->cfg->max_tokens);
+ }
+ else if (too_small) {
+ g_set_error(err, rspamd_stat_quark(), 204,
+ "<%s> contains less tokens than required for %s classifier: "
+ "%d < %d",
+ MESSAGE_FIELD(task, message_id),
+ sel->cfg->name,
+ task->tokens->len,
+ sel->cfg->min_tokens);
+ }
+ }
+
+ return learned;
+}
+
+rspamd_stat_result_t
+rspamd_stat_learn_class(struct rspamd_task *task,
+ const char *class_name,
+ lua_State *L,
+ const char *classifier,
+ unsigned int stage,
+ GError **err)
+{
+ struct rspamd_stat_ctx *st_ctx;
+ rspamd_stat_result_t ret = RSPAMD_STAT_PROCESS_OK;
+
+ /*
+ * We assume now that a task has been already classified before
+ * coming to learn
+ */
+ g_assert(RSPAMD_TASK_IS_CLASSIFIED(task));
+
+ st_ctx = rspamd_stat_get_ctx();
+ g_assert(st_ctx != NULL);
+
+ msg_debug_bayes("learn class stage %d has been called for class '%s'", stage, class_name);
+
+ if (st_ctx->classifiers->len == 0) {
+ msg_debug_bayes("no classifiers defined");
+ task->processed_stages |= stage;
+ return ret;
+ }
+
+ if (task->message == NULL) {
+ ret = RSPAMD_STAT_PROCESS_ERROR;
+ if (err && *err == NULL) {
+ g_set_error(err, rspamd_stat_quark(), 500,
+ "Trying to learn an empty message");
+ }
+
+ task->processed_stages |= stage;
+ return ret;
+ }
+
+ if (stage == RSPAMD_TASK_STAGE_LEARN_PRE) {
+ /* Process classifiers - determine spam boolean for compatibility */
+ gboolean spam = (strcmp(class_name, "spam") == 0 || strcmp(class_name, "S") == 0);
+ rspamd_stat_preprocess(st_ctx, task, TRUE, spam);
+
+ if (!rspamd_stat_cache_check(st_ctx, task, classifier, spam, err)) {
+ msg_debug_bayes("cache check failed, skip learning");
+ return RSPAMD_STAT_PROCESS_ERROR;
+ }
+ }
+ else if (stage == RSPAMD_TASK_STAGE_LEARN) {
+ /* Process classifiers */
+ if (!rspamd_stat_classifiers_learn_class(st_ctx, task, classifier,
+ class_name, err)) {
+ if (err && *err == NULL) {
+ g_set_error(err, rspamd_stat_quark(), 500,
+ "Unknown statistics error, found when learning classifiers;"
+ " classifier: %s",
+ task->classifier);
+ }
+ return RSPAMD_STAT_PROCESS_ERROR;
+ }
+
+ /* Process backends - determine spam boolean for compatibility */
+ gboolean spam = (strcmp(class_name, "spam") == 0 || strcmp(class_name, "S") == 0);
+ if (!rspamd_stat_backends_learn(st_ctx, task, classifier, spam, err)) {
+ if (err && *err == NULL) {
+ g_set_error(err, rspamd_stat_quark(), 500,
+ "Unknown statistics error, found when storing data on backend;"
+ " classifier: %s",
+ task->classifier);
+ }
+ return RSPAMD_STAT_PROCESS_ERROR;
+ }
+ }
+ else if (stage == RSPAMD_TASK_STAGE_LEARN_POST) {
+ /* Process backends - determine spam boolean for compatibility */
+ gboolean spam = (strcmp(class_name, "spam") == 0 || strcmp(class_name, "S") == 0);
+ if (!rspamd_stat_backends_post_learn(st_ctx, task, classifier, spam, err)) {
+ return RSPAMD_STAT_PROCESS_ERROR;
+ }
+ }
+
+ task->processed_stages |= stage;
+
+ return ret;
+}
+
rspamd_stat_result_t
rspamd_stat_learn(struct rspamd_task *task,
gboolean spam, lua_State *L, const char *classifier, unsigned int stage,
if (mres) {
if (mres->score > rspamd_task_get_required_score(task, mres)) {
- task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
-
+ rspamd_task_set_autolearn_class(task, "spam");
ret = TRUE;
}
else if (mres->score < 0) {
- task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
+ rspamd_task_set_autolearn_class(task, "ham");
ret = TRUE;
}
}
if (mres) {
if (mres->score >= spam_score) {
- task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
-
+ rspamd_task_set_autolearn_class(task, "spam");
ret = TRUE;
}
else if (mres->score <= ham_score) {
- task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
+ rspamd_task_set_autolearn_class(task, "ham");
ret = TRUE;
}
}
/* We can have immediate results */
if (lua_ret) {
if (strcmp(lua_ret, "ham") == 0) {
- task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
+ rspamd_task_set_autolearn_class(task, "ham");
ret = TRUE;
}
else if (strcmp(lua_ret, "spam") == 0) {
- task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
+ rspamd_task_set_autolearn_class(task, "spam");
+ ret = TRUE;
+ }
+ else {
+ /* Multi-class: any other class name */
+ rspamd_task_set_autolearn_class(task, lua_ret);
ret = TRUE;
}
}
}
}
else if (ucl_object_type(obj) == UCL_OBJECT) {
- /* Try to find autolearn callback */
- if (cl->autolearn_cbref == 0) {
- /* We don't have preprocessed cb id, so try to get it */
- if (!rspamd_lua_require_function(L, "lua_bayes_learn",
- "autolearn")) {
- msg_err_task("cannot get autolearn library from "
- "`lua_bayes_learn`");
- }
- else {
- cl->autolearn_cbref = luaL_ref(L, LUA_REGISTRYINDEX);
+ /* Check if this is a multi-class autolearn configuration */
+ const ucl_object_t *multiclass_obj = ucl_object_lookup(obj, "multiclass");
+
+ if (multiclass_obj && ucl_object_type(multiclass_obj) == UCL_OBJECT) {
+ /* Multi-class threshold-based autolearn */
+ const ucl_object_t *thresholds_obj = ucl_object_lookup(multiclass_obj, "thresholds");
+
+ if (thresholds_obj && ucl_object_type(thresholds_obj) == UCL_OBJECT) {
+ /* Iterate through class thresholds */
+ ucl_object_iter_t it = NULL;
+ const ucl_object_t *class_obj;
+ const char *class_name;
+
+ while ((class_obj = ucl_object_iterate(thresholds_obj, &it, true))) {
+ class_name = ucl_object_key(class_obj);
+
+ if (class_name && ucl_object_type(class_obj) == UCL_ARRAY && class_obj->len == 2) {
+ /* [min_score, max_score] for this class */
+ const ucl_object_t *min_elt = ucl_array_find_index(class_obj, 0);
+ const ucl_object_t *max_elt = ucl_array_find_index(class_obj, 1);
+
+ if ((ucl_object_type(min_elt) == UCL_FLOAT || ucl_object_type(min_elt) == UCL_INT) &&
+ (ucl_object_type(max_elt) == UCL_FLOAT || ucl_object_type(max_elt) == UCL_INT)) {
+
+ double min_score = ucl_object_todouble(min_elt);
+ double max_score = ucl_object_todouble(max_elt);
+
+ if (mres && mres->score >= min_score && mres->score <= max_score) {
+ rspamd_task_set_autolearn_class(task, class_name);
+ ret = TRUE;
+ msg_debug_bayes("multiclass autolearn: score %.2f matches class '%s' [%.2f, %.2f]",
+ mres->score, class_name, min_score, max_score);
+ break; /* Stop at first matching class */
+ }
+ }
+ }
+ }
}
}
-
- if (cl->autolearn_cbref != -1) {
- lua_pushcfunction(L, &rspamd_lua_traceback);
- err_idx = lua_gettop(L);
- lua_rawgeti(L, LUA_REGISTRYINDEX, cl->autolearn_cbref);
-
- ptask = lua_newuserdata(L, sizeof(struct rspamd_task *));
- *ptask = task;
- rspamd_lua_setclass(L, rspamd_task_classname, -1);
- /* Push the whole object as well */
- ucl_object_push_lua(L, obj, true);
-
- if (lua_pcall(L, 2, 1, err_idx) != 0) {
- msg_err_task("call to autolearn script failed: "
- "%s",
- lua_tostring(L, -1));
+ else {
+ /* Try to find autolearn callback */
+ if (cl->autolearn_cbref == 0) {
+ /* We don't have preprocessed cb id, so try to get it */
+ if (!rspamd_lua_require_function(L, "lua_bayes_learn",
+ "autolearn")) {
+ msg_err_task("cannot get autolearn library from "
+ "`lua_bayes_learn`");
+ }
+ else {
+ cl->autolearn_cbref = luaL_ref(L, LUA_REGISTRYINDEX);
+ }
}
- else {
- lua_ret = lua_tostring(L, -1);
- if (lua_ret) {
- if (strcmp(lua_ret, "ham") == 0) {
- task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
- ret = TRUE;
- }
- else if (strcmp(lua_ret, "spam") == 0) {
- task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
- ret = TRUE;
+ if (cl->autolearn_cbref != -1) {
+ lua_pushcfunction(L, &rspamd_lua_traceback);
+ err_idx = lua_gettop(L);
+ lua_rawgeti(L, LUA_REGISTRYINDEX, cl->autolearn_cbref);
+
+ ptask = lua_newuserdata(L, sizeof(struct rspamd_task *));
+ *ptask = task;
+ rspamd_lua_setclass(L, rspamd_task_classname, -1);
+ /* Push the whole object as well */
+ ucl_object_push_lua(L, obj, true);
+
+ if (lua_pcall(L, 2, 1, err_idx) != 0) {
+ msg_err_task("call to autolearn script failed: "
+ "%s",
+ lua_tostring(L, -1));
+ }
+ else {
+ lua_ret = lua_tostring(L, -1);
+
+ if (lua_ret) {
+ if (strcmp(lua_ret, "ham") == 0) {
+ rspamd_task_set_autolearn_class(task, "ham");
+ ret = TRUE;
+ }
+ else if (strcmp(lua_ret, "spam") == 0) {
+ rspamd_task_set_autolearn_class(task, "spam");
+ ret = TRUE;
+ }
+ else {
+ /* Multi-class: any other class name */
+ rspamd_task_set_autolearn_class(task, lua_ret);
+ ret = TRUE;
+ }
}
}
- }
- lua_settop(L, err_idx - 1);
+ lua_settop(L, err_idx - 1);
+ }
}
- }
- if (ret) {
- /* Do not autolearn if we have this symbol already */
- if (rspamd_stat_has_classifier_symbols(task, mres, cl)) {
- ret = FALSE;
- task->flags &= ~(RSPAMD_TASK_FLAG_LEARN_HAM |
- RSPAMD_TASK_FLAG_LEARN_SPAM);
- }
- else if (mres != NULL) {
- if (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM) {
- msg_info_task("<%s>: autolearn ham for classifier "
- "'%s' as message's "
- "score is negative: %.2f",
- MESSAGE_FIELD(task, message_id), cl->cfg->name,
- mres->score);
- }
- else {
- msg_info_task("<%s>: autolearn spam for classifier "
- "'%s' as message's "
- "action is reject, score: %.2f",
- MESSAGE_FIELD(task, message_id), cl->cfg->name,
- mres->score);
+ if (ret) {
+ /* Do not autolearn if we have this symbol already */
+ if (rspamd_stat_has_classifier_symbols(task, mres, cl)) {
+ ret = FALSE;
+ task->flags &= ~(RSPAMD_TASK_FLAG_LEARN_HAM |
+ RSPAMD_TASK_FLAG_LEARN_SPAM |
+ RSPAMD_TASK_FLAG_LEARN_CLASS);
+ /* Clear the autolearn class from mempool */
+ rspamd_mempool_set_variable(task->task_pool, "autolearn_class", NULL, NULL);
}
+ else if (mres != NULL) {
+ const char *autolearn_class = rspamd_task_get_autolearn_class(task);
+
+ if (autolearn_class) {
+ if (strcmp(autolearn_class, "ham") == 0) {
+ msg_info_task("<%s>: autolearn ham for classifier "
+ "'%s' as message's "
+ "score is negative: %.2f",
+ MESSAGE_FIELD(task, message_id), cl->cfg->name,
+ mres->score);
+ }
+ else if (strcmp(autolearn_class, "spam") == 0) {
+ msg_info_task("<%s>: autolearn spam for classifier "
+ "'%s' as message's "
+ "action is reject, score: %.2f",
+ MESSAGE_FIELD(task, message_id), cl->cfg->name,
+ mres->score);
+ }
+ else {
+ msg_info_task("<%s>: autolearn class '%s' for classifier "
+ "'%s', score: %.2f",
+ MESSAGE_FIELD(task, message_id), autolearn_class,
+ cl->cfg->name, mres->score);
+ }
+ }
- task->classifier = cl->cfg->name;
- break;
+ task->classifier = cl->cfg->name;
+ break;
+ }
}
}
}