-- Lua script to perform bayes classification (multi-class)
-- This script accepts the following parameters:
-- key1 - prefix for bayes tokens (e.g. for per-user classification)
--- key2 - class labels: either table of all class labels (multi-class) or single string (binary)
+-- key2 - class labels: table of all class labels as "TABLE:label1,label2,..."
-- key3 - set of tokens encoded in messagepack array of strings
local prefix = KEYS[1]
local class_labels_arg = KEYS[2]
local input_tokens = cmsgpack.unpack(KEYS[3])
--- Determine if this is multi-class (table) or binary (string)
+-- Parse class labels (always expect TABLE: format)
local class_labels = {}
-
--- Check if this is a table serialized as "TABLE:label1,label2,..."
if string.match(class_labels_arg, "^TABLE:") then
local labels_str = string.sub(class_labels_arg, 7) -- Remove "TABLE:" prefix
- -- Split by comma
for label in string.gmatch(labels_str, "([^,]+)") do
table.insert(class_labels, label)
end
else
- -- Binary compatibility: handle old boolean or single string format
- if class_labels_arg == "true" then
- class_labels = { "S" } -- spam
- elseif class_labels_arg == "false" then
- class_labels = { "H" } -- ham
- else
- class_labels = { class_labels_arg } -- single class label
- end
+ -- Legacy single class - convert to array
+ class_labels = { class_labels_arg }
end
--- Get learned counts for all classes
+-- Get learned counts for all classes (ordered)
local learned_counts = {}
for _, label in ipairs(class_labels) do
local key = 'learns_' .. string.lower(label)
- -- Also try legacy keys for backward compatibility
+ -- Handle legacy keys for backward compatibility
if label == 'H' then
key = 'learns_ham'
elseif label == 'S' then
key = 'learns_spam'
end
- learned_counts[label] = tonumber(redis.call('HGET', prefix, key)) or 0
+ table.insert(learned_counts, tonumber(redis.call('HGET', prefix, key)) or 0)
+end
+
+-- Get token data for all classes (ordered)
+local token_results = {}
+for i, label in ipairs(class_labels) do
+ token_results[i] = {}
end
--- Get token data for all classes (only if we have learns for any class)
-local outputs = {}
+-- Check if we have any learning data
local has_learns = false
-for _, count in pairs(learned_counts) do
+for _, count in ipairs(learned_counts) do
if count > 0 then
has_learns = true
break
end
if has_learns then
- -- Initialize outputs for each class
- for _, label in ipairs(class_labels) do
- outputs[label] = {}
- end
-
-- Process each token
for i, token in ipairs(input_tokens) do
local token_data = redis.call('HMGET', token, unpack(class_labels))
if token_data then
for j, label in ipairs(class_labels) do
local count = token_data[j]
- if count then
- table.insert(outputs[label], { i, tonumber(count) })
+ if count and tonumber(count) > 0 then
+ table.insert(token_results[j], { i, tonumber(count) })
end
end
end
end
end
--- Format output for backward compatibility
-if #class_labels == 2 and class_labels[1] == 'H' and class_labels[2] == 'S' then
- -- Binary format: [learned_ham, learned_spam, output_ham, output_spam]
- return {
- learned_counts['H'] or 0,
- learned_counts['S'] or 0,
- outputs['H'] or {},
- outputs['S'] or {}
- }
-elseif #class_labels == 2 and class_labels[1] == 'S' and class_labels[2] == 'H' then
- -- Binary format: [learned_ham, learned_spam, output_ham, output_spam]
- return {
- learned_counts['H'] or 0,
- learned_counts['S'] or 0,
- outputs['H'] or {},
- outputs['S'] or {}
- }
-else
- -- Multi-class format: [learned_counts_table, outputs_table]
- return { learned_counts, outputs }
-end
+-- Always return ordered arrays: [learned_counts_array, token_results_array]
+return { learned_counts, token_results }
static const char *helo = nullptr;
static const char *hostname = nullptr;
static const char *classifier = nullptr;
+static const char *learn_class_name = nullptr;
static const char *local_addr = nullptr;
static const char *execute = nullptr;
static const char *sort = nullptr;
RSPAMC_COMMAND_SYMBOLS,
RSPAMC_COMMAND_LEARN_SPAM,
RSPAMC_COMMAND_LEARN_HAM,
+ RSPAMC_COMMAND_LEARN_CLASS,
RSPAMC_COMMAND_FUZZY_ADD,
RSPAMC_COMMAND_FUZZY_DEL,
RSPAMC_COMMAND_FUZZY_DELHASH,
.is_privileged = TRUE,
.need_input = TRUE,
.command_output_func = nullptr},
+ rspamc_command{
+ .cmd = RSPAMC_COMMAND_LEARN_CLASS,
+ .name = "learn_class",
+ .path = "learnclass",
+ .description = "learn message as class",
+ .is_controller = TRUE,
+ .is_privileged = TRUE,
+ .need_input = TRUE,
+ .command_output_func = nullptr},
rspamc_command{
.cmd = RSPAMC_COMMAND_FUZZY_ADD,
.name = "fuzzy_add",
auto *map = (char *) locked_mmap.value().get_map();
value_view = std::string_view{map, locked_mmap->get_size()};
auto right = value_view.end() - 1;
- for (; right > value_view.cbegin() && g_ascii_isspace(*right); --right)
- ;
+ for (; right > value_view.cbegin() && g_ascii_isspace(*right); --right);
std::string_view str{value_view.begin(), static_cast<size_t>(right - value_view.begin()) + 1};
processed_passwd.assign(std::begin(str), std::end(str));
processed_passwd.push_back('\0'); /* Null-terminate for C part */
{"report", RSPAMC_COMMAND_SYMBOLS},
{"learn_spam", RSPAMC_COMMAND_LEARN_SPAM},
{"learn_ham", RSPAMC_COMMAND_LEARN_HAM},
+ {"learn_class", RSPAMC_COMMAND_LEARN_CLASS},
{"fuzzy_add", RSPAMC_COMMAND_FUZZY_ADD},
{"fuzzy_del", RSPAMC_COMMAND_FUZZY_DEL},
{"fuzzy_delhash", RSPAMC_COMMAND_FUZZY_DELHASH},
});
std::string cmd_lc = rspamd_string_tolower(cmd);
+
+ // Handle learn_class:classname syntax
+ if (cmd_lc.find("learn_class:") == 0) {
+ auto colon_pos = cmd_lc.find(':');
+ if (colon_pos != std::string::npos && colon_pos + 1 < cmd_lc.length()) {
+ auto class_name = cmd_lc.substr(colon_pos + 1);
+ // Store class name globally for later use
+ learn_class_name = g_strdup(class_name.c_str());
+ // Return the learn_class command
+ auto elt_it = std::find_if(rspamc_commands.begin(), rspamc_commands.end(), [&](const auto &item) {
+ return item.cmd == RSPAMC_COMMAND_LEARN_CLASS;
+ });
+ if (elt_it != std::end(rspamc_commands)) {
+ return *elt_it;
+ }
+ }
+ return std::nullopt;
+ }
+
auto ct = rspamd::find_map(str_map, std::string_view{cmd_lc});
+ if (!ct.has_value()) {
+ return std::nullopt;
+ }
+
auto elt_it = std::find_if(rspamc_commands.begin(), rspamc_commands.end(), [&](const auto &item) {
- return item.cmd == ct;
+ return item.cmd == ct.value();
});
if (elt_it != std::end(rspamc_commands)) {
add_client_header(opts, "Classifier", classifier);
}
+ if (learn_class_name) {
+ add_client_header(opts, "Class", learn_class_name);
+ }
+
if (weight != 0) {
auto nstr = fmt::format("{}", weight);
add_client_header(opts, "Weight", nstr.c_str());
if (raw_body) {
/* We can also output the resulting json */
- rspamc_print(out, "{}\n", std::string_view{raw_body, (std::size_t)(rawlen - bodylen)});
+ rspamc_print(out, "{}\n", std::string_view{raw_body, (std::size_t) (rawlen - bodylen)});
}
}
}
p = strrchr(connect_str, ']');
if (p != nullptr) {
- hostbuf.assign(connect_str + 1, (std::size_t)(p - connect_str - 1));
+ hostbuf.assign(connect_str + 1, (std::size_t) (p - connect_str - 1));
p++;
}
else {
if (hostbuf.empty()) {
if (p != nullptr) {
- hostbuf.assign(connect_str, (std::size_t)(p - connect_str));
+ hostbuf.assign(connect_str, (std::size_t) (p - connect_str));
}
else {
hostbuf.assign(connect_str);
#define PATH_HISTORY_RESET "/historyreset"
#define PATH_LEARN_SPAM "/learnspam"
#define PATH_LEARN_HAM "/learnham"
+#define PATH_LEARN_CLASS "/learnclass"
#define PATH_METRICS "/metrics"
#define PATH_READY "/ready"
#define PATH_SAVE_ACTIONS "/saveactions"
struct rspamd_controller_worker_ctx *ctx;
struct rspamd_task *task;
const rspamd_ftok_t *cl_header;
+ const char *class_name;
ctx = session->ctx;
goto end;
}
- rspamd_learn_task_spam(task, is_spam, session->classifier, NULL);
+ /* Use unified class-based learning approach */
+ class_name = is_spam ? "spam" : "ham";
+ rspamd_task_set_autolearn_class(task, class_name);
if (!rspamd_task_process(task, RSPAMD_TASK_PROCESS_LEARN)) {
msg_warn_session("<%s> message cannot be processed",
return rspamd_controller_handle_learn_common(conn_ent, msg, FALSE);
}
+/*
+ * Learn class command handler:
+ * request: /learnclass
+ * headers: Password, Class
+ * input: plaintext data
+ * reply: json {"success":true} or {"error":"error message"}
+ */
+static int
+rspamd_controller_handle_learnclass(
+ struct rspamd_http_connection_entry *conn_ent,
+ struct rspamd_http_message *msg)
+{
+ struct rspamd_controller_session *session = conn_ent->ud;
+ struct rspamd_controller_worker_ctx *ctx;
+ struct rspamd_task *task;
+ const rspamd_ftok_t *cl_header, *class_header;
+ char *class_name = NULL;
+
+ ctx = session->ctx;
+
+ if (!rspamd_controller_check_password(conn_ent, session, msg, TRUE)) {
+ return 0;
+ }
+
+ if (rspamd_http_message_get_body(msg, NULL) == NULL) {
+ msg_err_session("got zero length body, cannot continue");
+ rspamd_controller_send_error(conn_ent,
+ 400,
+ "Empty body is not permitted");
+ return 0;
+ }
+
+ class_header = rspamd_http_message_find_header(msg, "Class");
+ if (!class_header) {
+ msg_err_session("missing Class header for multiclass learning");
+ rspamd_controller_send_error(conn_ent,
+ 400,
+ "Class header is required for multiclass learning");
+ return 0;
+ }
+
+ task = rspamd_task_new(session->ctx->worker, session->cfg, session->pool,
+ session->ctx->lang_det, ctx->event_loop, FALSE);
+
+ task->resolver = ctx->resolver;
+ task->s = rspamd_session_create(session->pool,
+ rspamd_controller_learn_fin_task,
+ NULL,
+ (event_finalizer_t) rspamd_task_free,
+ task);
+ task->fin_arg = conn_ent;
+ task->http_conn = rspamd_http_connection_ref(conn_ent->conn);
+ task->sock = -1;
+ session->task = task;
+
+ cl_header = rspamd_http_message_find_header(msg, "classifier");
+ if (cl_header) {
+ session->classifier = rspamd_mempool_ftokdup(session->pool, cl_header);
+ }
+ else {
+ session->classifier = NULL;
+ }
+
+ if (!rspamd_task_load_message(task, msg, msg->body_buf.begin, msg->body_buf.len)) {
+ goto end;
+ }
+
+ /* Set multiclass learning flag and store class name */
+ class_name = rspamd_mempool_ftokdup(task->task_pool, class_header);
+ rspamd_task_set_autolearn_class(task, class_name);
+
+ if (!rspamd_task_process(task, RSPAMD_TASK_PROCESS_LEARN)) {
+ msg_warn_session("<%s> message cannot be processed",
+ MESSAGE_FIELD_CHECK(task, message_id));
+ goto end;
+ }
+
+end:
+ /* Set session spam flag for logging compatibility */
+ if (class_name) {
+ session->is_spam = (strcmp(class_name, "spam") == 0);
+ }
+ else {
+ session->is_spam = FALSE;
+ }
+ rspamd_session_pending(task->s);
+
+ return 0;
+}
+
/*
* Scan command handler:
* request: /scan
rspamd_http_message_add_header(rep, "Access-Control-Allow-Methods",
"POST, GET, OPTIONS");
rspamd_http_message_add_header(rep, "Access-Control-Allow-Headers",
- "Classifier,Content-Type,Password,Map,Weight,Flag,Hash");
+ "Classifier,Class,Content-Type,Password,Map,Weight,Flag,Hash");
rspamd_http_connection_reset(conn_ent->conn);
rspamd_http_router_insert_headers(conn_ent->rt, rep);
rspamd_http_connection_write_message(conn_ent->conn,
*/
static int
rspamd_controller_handle_bayes_classifiers(struct rspamd_http_connection_entry *conn_ent,
- struct rspamd_http_message *msg)
+ struct rspamd_http_message *msg)
{
struct rspamd_controller_session *session = conn_ent->ud;
struct rspamd_controller_worker_ctx *ctx = session->ctx;
rspamd_http_router_add_path(ctx->http,
PATH_LEARN_HAM,
rspamd_controller_handle_learnham);
+ rspamd_http_router_add_path(ctx->http,
+ PATH_LEARN_CLASS,
+ rspamd_controller_handle_learnclass);
rspamd_http_router_add_path(ctx->http,
PATH_METRICS,
rspamd_controller_handle_metrics);
char *label; /**< label of this statfile */
ucl_object_t *opts; /**< other options */
char *class_name; /**< class name for multi-class classification */
+ unsigned int class_index; /**< class index for O(1) lookup during classification */
gboolean is_spam; /**< DEPRECATED: spam flag - use class_name instead */
struct rspamd_classifier_config *clcf; /**< parent pointer of classifier configuration */
gpointer data; /**< opaque data */
cfg->classifiers = g_list_prepend(cfg->classifiers, ccf);
+ /* Populate class_names array from statfiles */
+ if (ccf->statfiles) {
+ GList *cur = ccf->statfiles;
+ ccf->class_names = g_ptr_array_new();
+
+ while (cur) {
+ struct rspamd_statfile_config *stcf = (struct rspamd_statfile_config *) cur->data;
+ if (stcf->class_name) {
+ /* Check if class already exists */
+ bool found = false;
+ for (unsigned int i = 0; i < ccf->class_names->len; i++) {
+ if (strcmp((char *) g_ptr_array_index(ccf->class_names, i), stcf->class_name) == 0) {
+ stcf->class_index = i; /* Store the index for O(1) lookup */
+ found = true;
+ break;
+ }
+ }
+
+ if (!found) {
+ /* Add new class */
+ stcf->class_index = ccf->class_names->len;
+ g_ptr_array_add(ccf->class_names, g_strdup(stcf->class_name));
+ }
+ }
+ cur = g_list_next(cur);
+ }
+ }
+
return TRUE;
}
const char *classifier,
GError **err)
{
+ /* Use unified class-based approach internally */
+ const char *class_name = is_spam ? "spam" : "ham";
+
/* Disable learn auto flag to avoid bad learn codes */
task->flags &= ~RSPAMD_TASK_FLAG_LEARN_AUTO;
- if (is_spam) {
- task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
- }
- else {
- task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
- }
+ /* Use the unified class-based learning approach */
+ rspamd_task_set_autolearn_class(task, class_name);
task->classifier = classifier;
#include "contrib/fmt/include/fmt/base.h"
#include "libutil/cxx/error.hxx"
+#include <map>
#include <string>
#include <cstdint>
rspamd_token_t *tok;
if (!results) {
+ msg_debug_bayes("process_tokens: no results available for statfile id=%d", id);
return false;
}
+ msg_debug_bayes("processing tokens for statfile id=%d, results size=%uz, class=%s",
+ id, results->size(), stcf->class_name ? stcf->class_name : "unknown");
+
for (auto [idx, val]: *results) {
tok = (rspamd_token_t *) g_ptr_array_index(tokens, idx - 1);
+ msg_debug_bayes("setting tok->values[%d] = %.2f for token idx %d (class=%s)",
+ id, val, idx, stcf->class_name ? stcf->class_name : "unknown");
tok->values[id] = val;
}
/* No cached result (or learn), create new one */
auto *rt = new redis_stat_runtime<float>(ctx, task, object_expanded);
- if (!learn && stcf->clcf && stcf->clcf->class_names && stcf->clcf->class_names->len > 2) {
- /*
- * For multi-class classification, we need to create runtimes for ALL classes
- * to avoid multiple Redis calls. The actual Redis call will fetch data for all classes.
- */
+ /* Find the statfile ID for the main runtime */
+ int main_id = _id; /* Use the passed _id parameter */
+ rt->id = main_id;
+ rt->stcf = stcf;
+
+ /* For classification, create runtimes for all other statfiles to avoid multiple Redis calls */
+ if (!learn && stcf->clcf && stcf->clcf->statfiles) {
GList *cur = stcf->clcf->statfiles;
+
while (cur) {
auto *other_stcf = (struct rspamd_statfile_config *) cur->data;
- if (other_stcf != stcf) {
- const char *other_label = get_class_label(other_stcf);
-
- auto maybe_other_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
- object_expanded, other_label);
- if (!maybe_other_rt) {
- auto *other_rt = new redis_stat_runtime<float>(ctx, task, object_expanded);
- other_rt->save_in_mempool(other_label);
- other_rt->need_redis_call = false;
+ const char *other_label = get_class_label(other_stcf);
+
+ /* Find the statfile ID by searching through all statfiles */
+ struct rspamd_stat_ctx *st_ctx = rspamd_stat_get_ctx();
+ int other_id = -1;
+ for (unsigned int i = 0; i < st_ctx->statfiles->len; i++) {
+ struct rspamd_statfile *st = (struct rspamd_statfile *) g_ptr_array_index(st_ctx->statfiles, i);
+ if (st->stcf == other_stcf) {
+ other_id = st->id;
+ msg_debug_bayes("found statfile mapping: %s (class=%s) → id=%d",
+ st->stcf->symbol, other_label, other_id);
+ break;
}
}
+
+ if (other_id == -1) {
+ msg_debug_bayes("statfile not found for class %s, skipping", other_label);
+ /* Skip if statfile not found */
+ cur = g_list_next(cur);
+ continue;
+ }
+
+ if (other_stcf == stcf) {
+ /* This is the main statfile, use the main runtime */
+ rt->save_in_mempool(other_label);
+ msg_debug_bayes("main runtime: statfile %s (class=%s) → id=%d",
+ stcf->symbol, other_label, rt->id);
+ }
+ else {
+ /* Create additional runtime for other statfile */
+ auto *other_rt = new redis_stat_runtime<float>(ctx, task, object_expanded);
+ other_rt->id = other_id;
+ other_rt->stcf = other_stcf;
+ other_rt->save_in_mempool(other_label);
+ msg_debug_bayes("additional runtime: statfile %s (class=%s) → id=%d",
+ other_stcf->symbol, other_label, other_id);
+ }
+
cur = g_list_next(cur);
}
}
- else if (!learn) {
- /*
- * For binary classification, create the opposite class runtime to avoid
- * double call for Redis scripts (backward compatibility).
- */
- const char *opposite_label = stcf->is_spam ? "H" : "S";
- auto maybe_opposite_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
- object_expanded, opposite_label);
-
- if (!maybe_opposite_rt) {
- auto *opposite_rt = new redis_stat_runtime<float>(ctx, task, object_expanded);
- opposite_rt->save_in_mempool(opposite_label);
- opposite_rt->need_redis_call = false;
- }
+ else {
+ rt->save_in_mempool(class_label);
}
- rt->save_in_mempool(class_label);
-
return rt;
}
if (rt == nullptr) {
msg_err_task("internal error: cannot find runtime for cookie %s", cookie);
-
return 0;
}
return 0;
}
- /* Check the array length to determine format:
- * - Length 4: binary format [learned_ham, learned_spam, ham_tokens, spam_tokens]
- * - Length 2: multi-class format [learned_counts_table, outputs_table]
- */
+ /* Redis returns [learned_counts_array, token_results_array]
+ * Both ordered the same way as statfiles in classifier */
size_t result_len = rspamd_lua_table_size(L, 3);
msg_debug_bayes("Redis result array length: %uz", result_len);
- auto filler_func = [](redis_stat_runtime<float> *rt, lua_State *L, unsigned learned, int tokens_pos) {
- rt->learned = learned;
- redis_stat_runtime<float>::result_type *res;
+ if (result_len != 2) {
+ msg_err_task("internal error: expected 2-element result from Redis script, got %uz", result_len);
+ rt->err = rspamd::util::error("invalid Redis script result format", 500);
+ return 0;
+ }
- res = new redis_stat_runtime<float>::result_type();
+ /* Get learned_counts_array and token_results_array */
+ lua_rawgeti(L, 3, 1); /* learned_counts -> position 4 */
+ lua_rawgeti(L, 3, 2); /* token_results -> position 5 */
- for (lua_pushnil(L); lua_next(L, tokens_pos); lua_pop(L, 1)) {
- lua_rawgeti(L, -1, 1);
- auto idx = lua_tointeger(L, -1);
- lua_pop(L, 1);
+ /* Process results for all statfiles in order using class_index (O(N) instead of O(N²)) */
+ if (rt->stcf->clcf && rt->stcf->clcf->statfiles) {
+ GList *cur = rt->stcf->clcf->statfiles;
+ int redis_idx = 1; /* Redis result array index (1-based) */
- lua_rawgeti(L, -1, 2);
- auto value = lua_tonumber(L, -1);
- lua_pop(L, 1);
+ while (cur) {
+ auto *stcf = (struct rspamd_statfile_config *) cur->data;
- res->emplace_back(idx, value);
- }
+ /* Direct statfile lookup using global statfiles array */
+ struct rspamd_stat_ctx *st_ctx = rspamd_stat_get_ctx();
+ struct rspamd_statfile *st = nullptr;
+
+ /* Find statfile by config pointer (still O(N) but unavoidable) */
+ for (unsigned int i = 0; i < st_ctx->statfiles->len; i++) {
+ struct rspamd_statfile *candidate = (struct rspamd_statfile *) g_ptr_array_index(st_ctx->statfiles, i);
+ if (candidate->stcf == stcf) {
+ st = candidate;
+ break;
+ }
+ }
+
+ if (!st) {
+ msg_debug_bayes("statfile not found for config %s, skipping", stcf->symbol);
+ cur = g_list_next(cur);
+ redis_idx++;
+ continue;
+ }
- rt->set_results(res);
- };
+ /* Get or create runtime for this statfile */
+ auto *statfile_rt = rt; /* Use current runtime for first statfile */
+ if (stcf != rt->stcf) {
+ const char *class_label = get_class_label(stcf);
+ auto maybe_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
+ rt->redis_object_expanded,
+ class_label);
+ if (maybe_rt) {
+ statfile_rt = maybe_rt.value();
+ }
+ else {
+ msg_debug_bayes("runtime not found for class %s, skipping", class_label);
+ cur = g_list_next(cur);
+ redis_idx++;
+ continue;
+ }
+ }
- bool is_binary_format = (result_len == 4);
+ /* Ensure correct statfile ID assignment */
+ statfile_rt->id = st->id;
+
+ /* Process token results for this statfile (Redis array index redis_idx) */
+ lua_rawgeti(L, 5, redis_idx); /* Get token_results[redis_idx] */
+ if (lua_istable(L, -1)) {
+ /* Parse token results into statfile runtime */
+ auto *res = new std::vector<std::pair<int, float>>();
+
+ lua_pushnil(L); /* First key for iteration */
+ while (lua_next(L, -2) != 0) {
+ if (lua_istable(L, -1) && lua_objlen(L, -1) == 2) {
+ lua_rawgeti(L, -1, 1); /* token_index */
+ lua_rawgeti(L, -2, 2); /* token_count */
+
+ if (lua_isnumber(L, -2) && lua_isnumber(L, -1)) {
+ int token_idx = lua_tointeger(L, -2);
+ float token_count = lua_tonumber(L, -1);
+ res->emplace_back(token_idx, token_count);
+ }
+
+ lua_pop(L, 2); /* Pop token_index and token_count */
+ }
+ lua_pop(L, 1); /* Pop value, keep key for next iteration */
+ }
- if (is_binary_format) {
- /* Binary format: [learned_ham, learned_spam, ham_tokens, spam_tokens] */
+ statfile_rt->set_results(res);
+ }
+ lua_pop(L, 1); /* Pop token_results[redis_idx] */
- /* Find the opposite runtime for binary classification compatibility */
- const char *opposite_label;
- if (rt->stcf->class_name) {
- /* Multi-class: find a different class (simplified for now) */
- opposite_label = strcmp(get_class_label(rt->stcf), "S") == 0 ? "H" : "S";
+ cur = g_list_next(cur);
+ redis_idx++;
}
- else {
- /* Binary: use opposite spam/ham */
- opposite_label = rt->stcf->is_spam ? "H" : "S";
- }
- auto opposite_rt_maybe = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
- rt->redis_object_expanded,
- opposite_label);
+ }
- if (!opposite_rt_maybe) {
- msg_err_task("internal error: cannot find opposite runtime for cookie %s", cookie);
- return 0;
- }
+ /* Clean up stack */
+ lua_pop(L, 2); /* Pop learned_counts and token_results */
- /* Extract values from the result table at position 3 */
- lua_rawgeti(L, 3, 1); /* learned_ham -> position 4 */
- lua_rawgeti(L, 3, 2); /* learned_spam -> position 5 */
- lua_rawgeti(L, 3, 3); /* ham_tokens -> position 6 */
- lua_rawgeti(L, 3, 4); /* spam_tokens -> position 7 */
+ /* Process tokens for all runtimes */
+ g_assert(rt->tokens != nullptr);
- unsigned learned_ham = lua_tointeger(L, 4);
- unsigned learned_spam = lua_tointeger(L, 5);
+ /* Process tokens for all statfiles */
+ if (rt->stcf->clcf && rt->stcf->clcf->statfiles) {
+ GList *cur = rt->stcf->clcf->statfiles;
- if (rt->stcf->is_spam || (rt->stcf->class_name && strcmp(get_class_label(rt->stcf), "S") == 0)) {
- /* Current runtime is spam, use spam data */
- filler_func(rt, L, learned_spam, 7); /* spam_tokens at position 7 */
- filler_func(opposite_rt_maybe.value(), L, learned_ham, 6); /* ham_tokens at position 6 */
- }
- else {
- /* Current runtime is ham, use ham data */
- filler_func(rt, L, learned_ham, 6); /* ham_tokens at position 6 */
- filler_func(opposite_rt_maybe.value(), L, learned_spam, 7); /* spam_tokens at position 7 */
- }
+ while (cur) {
+ auto *stcf = (struct rspamd_statfile_config *) cur->data;
+ const char *class_label = get_class_label(stcf);
- /* Clean up the stack - pop the 4 extracted values */
- lua_pop(L, 4);
+ auto maybe_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
+ rt->redis_object_expanded,
+ class_label);
+ if (maybe_rt) {
+ auto *statfile_rt = maybe_rt.value();
+ statfile_rt->process_tokens(rt->tokens);
+ }
- /* Process all tokens */
- g_assert(rt->tokens != nullptr);
- rt->process_tokens(rt->tokens);
- opposite_rt_maybe.value()->process_tokens(rt->tokens);
+ cur = g_list_next(cur);
+ }
}
else {
- /* Multi-class format: [learned_counts_table, outputs_table] */
-
- /* Get learned counts table (index 1) and outputs table (index 2) */
- lua_rawgeti(L, 3, 1); /* learned_counts */
- lua_rawgeti(L, 3, 2); /* outputs */
-
- /* Iterate through all class labels to fill all runtimes */
- if (rt->stcf->clcf && rt->stcf->clcf->class_labels) {
- GHashTableIter iter;
- gpointer key, value;
- g_hash_table_iter_init(&iter, rt->stcf->clcf->class_labels);
-
- while (g_hash_table_iter_next(&iter, &key, &value)) {
- const char *class_label = (const char *) value;
-
- /* Find runtime for this class */
- auto class_rt_maybe = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
- rt->redis_object_expanded,
- class_label);
-
- if (class_rt_maybe) {
- auto *class_rt = class_rt_maybe.value();
-
- /* Get learned count for this class */
- lua_pushstring(L, class_label);
- lua_gettable(L, -3); /* learned_counts[class_label] */
- unsigned learned = lua_tointeger(L, -1);
- lua_pop(L, 1);
-
- /* Get outputs for this class */
- lua_pushstring(L, class_label);
- lua_gettable(L, -2); /* outputs[class_label] */
- int outputs_pos = lua_gettop(L);
-
- filler_func(class_rt, L, learned, outputs_pos);
- lua_pop(L, 1);
- }
- }
- }
-
- lua_pop(L, 2); /* Pop learned_counts and outputs tables */
-
- /* Process tokens for all runtimes */
- g_assert(rt->tokens != nullptr);
+ /* Fallback: just process the main runtime */
rt->process_tokens(rt->tokens);
}
-
- /* Tokens processed - no need to set flags in multi-class approach */
}
else {
/* Error message is on index 3 */
val = tok->values[id];
if (val > 0) {
- /* Find which class this statfile belongs to */
- for (j = 0; j < cl->num_classes; j++) {
- if (st->stcf->class_name &&
- strcmp(st->stcf->class_name, cl->class_names[j]) == 0) {
- class_counts[j] += val;
- total_count += val;
- cl->total_hits += val;
- break;
- }
+ /* Direct O(1) class index lookup instead of O(N) string comparison */
+ if (st->stcf->class_name && st->stcf->class_index < cl->num_classes) {
+ unsigned int class_idx = st->stcf->class_index;
+ class_counts[class_idx] += val;
+ total_count += val;
+ cl->total_hits += val;
+ }
+ else {
+ msg_debug_bayes("invalid class_index %ud >= %ud for statfile %s",
+ st->stcf->class_index, cl->num_classes, st->stcf->symbol);
}
}
}
if (tok->t1 && tok->t2) {
msg_debug_bayes("token(%s) %uL <%*s:%*s>: weight: %.3f, total_count: %ud, "
- "processed for %u classes",
+ "processed for %ud classes",
token_type, tok->data,
(int) tok->t1->stemmed.len, tok->t1->stemmed.begin,
(int) tok->t2->stemmed.len, tok->t2->stemmed.begin,
cl.cfg = ctx->cfg;
/* Get class information from classifier config */
- if (!ctx->cfg->class_names || ctx->cfg->class_names->len < 2) {
- msg_debug_bayes("insufficient classes for multiclass classification");
+ if (!ctx->cfg->class_names) {
+ msg_debug_bayes("no class_names array in classifier config");
+ return TRUE; /* Fall back to binary mode */
+ }
+ if (ctx->cfg->class_names->len < 2) {
+ msg_debug_bayes("insufficient classes: %ud < 2", (unsigned int) ctx->cfg->class_names->len);
+ return TRUE; /* Fall back to binary mode */
+ }
+ if (!ctx->cfg->class_names->pdata) {
+ msg_debug_bayes("class_names->pdata is NULL");
return TRUE; /* Fall back to binary mode */
}
cl.num_classes = ctx->cfg->class_names->len;
cl.class_names = (char **) ctx->cfg->class_names->pdata;
+
+ /* Debug: verify class names are accessible */
+ msg_debug_bayes("multiclass setup: ctx->cfg->class_names=%p, len=%ud, pdata=%p",
+ ctx->cfg->class_names, (unsigned int) ctx->cfg->class_names->len, ctx->cfg->class_names->pdata);
+ msg_debug_bayes("multiclass setup: cl.num_classes=%ud, cl.class_names=%p",
+ cl.num_classes, cl.class_names);
cl.class_log_probs = g_alloca(cl.num_classes * sizeof(double));
cl.class_learns = g_alloca(cl.num_classes * sizeof(uint64_t));
}
if (cl.processed_tokens == 0) {
+ /* Debug: check why no tokens were processed */
+ msg_debug_bayes("examining token values for debugging:");
+ for (i = 0; i < MIN(tokens->len, 10); i++) { /* Check first 10 tokens */
+ tok = g_ptr_array_index(tokens, i);
+ for (j = 0; j < ctx->statfiles_ids->len; j++) {
+ id = g_array_index(ctx->statfiles_ids, int, j);
+ if (tok->values[id] > 0) {
+ struct rspamd_statfile *st = g_ptr_array_index(ctx->ctx->statfiles, id);
+ msg_debug_bayes("token %ud: values[%d] = %.2f (class=%s, symbol=%s)",
+ i, id, tok->values[id],
+ st->stcf->class_name ? st->stcf->class_name : "unknown",
+ st->stcf->symbol);
+ }
+ }
+ }
+
msg_info_bayes("no tokens found in bayes database "
"(%ud total tokens, %ud text tokens), ignore stats",
tokens->len, text_tokens);