From: Vsevolod Stakhov Date: Tue, 22 Jul 2025 21:30:03 +0000 (+0100) Subject: [Project] Further updates X-Git-Tag: 3.13.0~38^2~20 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=ab11cd145f785d6b7cc606a5410cddc70c2db5dd;p=thirdparty%2Frspamd.git [Project] Further updates --- diff --git a/lualib/redis_scripts/bayes_classify.lua b/lualib/redis_scripts/bayes_classify.lua index e07b9a7956..923adcc5ad 100644 --- a/lualib/redis_scripts/bayes_classify.lua +++ b/lualib/redis_scripts/bayes_classify.lua @@ -1,51 +1,47 @@ -- Lua script to perform bayes classification (multi-class) -- This script accepts the following parameters: -- key1 - prefix for bayes tokens (e.g. for per-user classification) --- key2 - class labels: either table of all class labels (multi-class) or single string (binary) +-- key2 - class labels: table of all class labels as "TABLE:label1,label2,..." -- key3 - set of tokens encoded in messagepack array of strings local prefix = KEYS[1] local class_labels_arg = KEYS[2] local input_tokens = cmsgpack.unpack(KEYS[3]) --- Determine if this is multi-class (table) or binary (string) +-- Parse class labels (always expect TABLE: format) local class_labels = {} - --- Check if this is a table serialized as "TABLE:label1,label2,..." if string.match(class_labels_arg, "^TABLE:") then local labels_str = string.sub(class_labels_arg, 7) -- Remove "TABLE:" prefix - -- Split by comma for label in string.gmatch(labels_str, "([^,]+)") do table.insert(class_labels, label) end else - -- Binary compatibility: handle old boolean or single string format - if class_labels_arg == "true" then - class_labels = { "S" } -- spam - elseif class_labels_arg == "false" then - class_labels = { "H" } -- ham - else - class_labels = { class_labels_arg } -- single class label - end + -- Legacy single class - convert to array + class_labels = { class_labels_arg } end --- Get learned counts for all classes +-- Get learned counts for all classes (ordered) local learned_counts = {} for _, label in ipairs(class_labels) do local key = 'learns_' .. string.lower(label) - -- Also try legacy keys for backward compatibility + -- Handle legacy keys for backward compatibility if label == 'H' then key = 'learns_ham' elseif label == 'S' then key = 'learns_spam' end - learned_counts[label] = tonumber(redis.call('HGET', prefix, key)) or 0 + table.insert(learned_counts, tonumber(redis.call('HGET', prefix, key)) or 0) +end + +-- Get token data for all classes (ordered) +local token_results = {} +for i, label in ipairs(class_labels) do + token_results[i] = {} end --- Get token data for all classes (only if we have learns for any class) -local outputs = {} +-- Check if we have any learning data local has_learns = false -for _, count in pairs(learned_counts) do +for _, count in ipairs(learned_counts) do if count > 0 then has_learns = true break @@ -53,11 +49,6 @@ for _, count in pairs(learned_counts) do end if has_learns then - -- Initialize outputs for each class - for _, label in ipairs(class_labels) do - outputs[label] = {} - end - -- Process each token for i, token in ipairs(input_tokens) do local token_data = redis.call('HMGET', token, unpack(class_labels)) @@ -65,32 +56,13 @@ if has_learns then if token_data then for j, label in ipairs(class_labels) do local count = token_data[j] - if count then - table.insert(outputs[label], { i, tonumber(count) }) + if count and tonumber(count) > 0 then + table.insert(token_results[j], { i, tonumber(count) }) end end end end end --- Format output for backward compatibility -if #class_labels == 2 and class_labels[1] == 'H' and class_labels[2] == 'S' then - -- Binary format: [learned_ham, learned_spam, output_ham, output_spam] - return { - learned_counts['H'] or 0, - learned_counts['S'] or 0, - outputs['H'] or {}, - outputs['S'] or {} - } -elseif #class_labels == 2 and class_labels[1] == 'S' and class_labels[2] == 'H' then - -- Binary format: [learned_ham, learned_spam, output_ham, output_spam] - return { - learned_counts['H'] or 0, - learned_counts['S'] or 0, - outputs['H'] or {}, - outputs['S'] or {} - } -else - -- Multi-class format: [learned_counts_table, outputs_table] - return { learned_counts, outputs } -end +-- Always return ordered arrays: [learned_counts_array, token_results_array] +return { learned_counts, token_results } diff --git a/src/client/rspamc.cxx b/src/client/rspamc.cxx index 4043598773..af88acb337 100644 --- a/src/client/rspamc.cxx +++ b/src/client/rspamc.cxx @@ -59,6 +59,7 @@ static const char *user = nullptr; static const char *helo = nullptr; static const char *hostname = nullptr; static const char *classifier = nullptr; +static const char *learn_class_name = nullptr; static const char *local_addr = nullptr; static const char *execute = nullptr; static const char *sort = nullptr; @@ -198,6 +199,7 @@ enum rspamc_command_type { RSPAMC_COMMAND_SYMBOLS, RSPAMC_COMMAND_LEARN_SPAM, RSPAMC_COMMAND_LEARN_HAM, + RSPAMC_COMMAND_LEARN_CLASS, RSPAMC_COMMAND_FUZZY_ADD, RSPAMC_COMMAND_FUZZY_DEL, RSPAMC_COMMAND_FUZZY_DELHASH, @@ -249,6 +251,15 @@ static const constexpr auto rspamc_commands = rspamd::array_of( .is_privileged = TRUE, .need_input = TRUE, .command_output_func = nullptr}, + rspamc_command{ + .cmd = RSPAMC_COMMAND_LEARN_CLASS, + .name = "learn_class", + .path = "learnclass", + .description = "learn message as class", + .is_controller = TRUE, + .is_privileged = TRUE, + .need_input = TRUE, + .command_output_func = nullptr}, rspamc_command{ .cmd = RSPAMC_COMMAND_FUZZY_ADD, .name = "fuzzy_add", @@ -527,8 +538,7 @@ rspamc_password_callback(const char *option_name, auto *map = (char *) locked_mmap.value().get_map(); value_view = std::string_view{map, locked_mmap->get_size()}; auto right = value_view.end() - 1; - for (; right > value_view.cbegin() && g_ascii_isspace(*right); --right) - ; + for (; right > value_view.cbegin() && g_ascii_isspace(*right); --right); std::string_view str{value_view.begin(), static_cast(right - value_view.begin()) + 1}; processed_passwd.assign(std::begin(str), std::end(str)); processed_passwd.push_back('\0'); /* Null-terminate for C part */ @@ -649,6 +659,7 @@ check_rspamc_command(const char *cmd) -> std::optional {"report", RSPAMC_COMMAND_SYMBOLS}, {"learn_spam", RSPAMC_COMMAND_LEARN_SPAM}, {"learn_ham", RSPAMC_COMMAND_LEARN_HAM}, + {"learn_class", RSPAMC_COMMAND_LEARN_CLASS}, {"fuzzy_add", RSPAMC_COMMAND_FUZZY_ADD}, {"fuzzy_del", RSPAMC_COMMAND_FUZZY_DEL}, {"fuzzy_delhash", RSPAMC_COMMAND_FUZZY_DELHASH}, @@ -659,10 +670,33 @@ check_rspamc_command(const char *cmd) -> std::optional }); std::string cmd_lc = rspamd_string_tolower(cmd); + + // Handle learn_class:classname syntax + if (cmd_lc.find("learn_class:") == 0) { + auto colon_pos = cmd_lc.find(':'); + if (colon_pos != std::string::npos && colon_pos + 1 < cmd_lc.length()) { + auto class_name = cmd_lc.substr(colon_pos + 1); + // Store class name globally for later use + learn_class_name = g_strdup(class_name.c_str()); + // Return the learn_class command + auto elt_it = std::find_if(rspamc_commands.begin(), rspamc_commands.end(), [&](const auto &item) { + return item.cmd == RSPAMC_COMMAND_LEARN_CLASS; + }); + if (elt_it != std::end(rspamc_commands)) { + return *elt_it; + } + } + return std::nullopt; + } + auto ct = rspamd::find_map(str_map, std::string_view{cmd_lc}); + if (!ct.has_value()) { + return std::nullopt; + } + auto elt_it = std::find_if(rspamc_commands.begin(), rspamc_commands.end(), [&](const auto &item) { - return item.cmd == ct; + return item.cmd == ct.value(); }); if (elt_it != std::end(rspamc_commands)) { @@ -799,6 +833,10 @@ add_options(GQueue *opts) add_client_header(opts, "Classifier", classifier); } + if (learn_class_name) { + add_client_header(opts, "Class", learn_class_name); + } + if (weight != 0) { auto nstr = fmt::format("{}", weight); add_client_header(opts, "Weight", nstr.c_str()); @@ -1918,7 +1956,7 @@ rspamc_client_cb(struct rspamd_client_connection *conn, if (raw_body) { /* We can also output the resulting json */ - rspamc_print(out, "{}\n", std::string_view{raw_body, (std::size_t)(rawlen - bodylen)}); + rspamc_print(out, "{}\n", std::string_view{raw_body, (std::size_t) (rawlen - bodylen)}); } } } @@ -1950,7 +1988,7 @@ rspamc_process_input(struct ev_loop *ev_base, const struct rspamc_command &cmd, p = strrchr(connect_str, ']'); if (p != nullptr) { - hostbuf.assign(connect_str + 1, (std::size_t)(p - connect_str - 1)); + hostbuf.assign(connect_str + 1, (std::size_t) (p - connect_str - 1)); p++; } else { @@ -1965,7 +2003,7 @@ rspamc_process_input(struct ev_loop *ev_base, const struct rspamc_command &cmd, if (hostbuf.empty()) { if (p != nullptr) { - hostbuf.assign(connect_str, (std::size_t)(p - connect_str)); + hostbuf.assign(connect_str, (std::size_t) (p - connect_str)); } else { hostbuf.assign(connect_str); diff --git a/src/controller.c b/src/controller.c index 0550ba6b86..6e0e4cac1e 100644 --- a/src/controller.c +++ b/src/controller.c @@ -53,6 +53,7 @@ #define PATH_HISTORY_RESET "/historyreset" #define PATH_LEARN_SPAM "/learnspam" #define PATH_LEARN_HAM "/learnham" +#define PATH_LEARN_CLASS "/learnclass" #define PATH_METRICS "/metrics" #define PATH_READY "/ready" #define PATH_SAVE_ACTIONS "/saveactions" @@ -2126,6 +2127,7 @@ rspamd_controller_handle_learn_common( struct rspamd_controller_worker_ctx *ctx; struct rspamd_task *task; const rspamd_ftok_t *cl_header; + const char *class_name; ctx = session->ctx; @@ -2167,7 +2169,9 @@ rspamd_controller_handle_learn_common( goto end; } - rspamd_learn_task_spam(task, is_spam, session->classifier, NULL); + /* Use unified class-based learning approach */ + class_name = is_spam ? "spam" : "ham"; + rspamd_task_set_autolearn_class(task, class_name); if (!rspamd_task_process(task, RSPAMD_TASK_PROCESS_LEARN)) { msg_warn_session("<%s> message cannot be processed", @@ -2211,6 +2215,96 @@ rspamd_controller_handle_learnham( return rspamd_controller_handle_learn_common(conn_ent, msg, FALSE); } +/* + * Learn class command handler: + * request: /learnclass + * headers: Password, Class + * input: plaintext data + * reply: json {"success":true} or {"error":"error message"} + */ +static int +rspamd_controller_handle_learnclass( + struct rspamd_http_connection_entry *conn_ent, + struct rspamd_http_message *msg) +{ + struct rspamd_controller_session *session = conn_ent->ud; + struct rspamd_controller_worker_ctx *ctx; + struct rspamd_task *task; + const rspamd_ftok_t *cl_header, *class_header; + char *class_name = NULL; + + ctx = session->ctx; + + if (!rspamd_controller_check_password(conn_ent, session, msg, TRUE)) { + return 0; + } + + if (rspamd_http_message_get_body(msg, NULL) == NULL) { + msg_err_session("got zero length body, cannot continue"); + rspamd_controller_send_error(conn_ent, + 400, + "Empty body is not permitted"); + return 0; + } + + class_header = rspamd_http_message_find_header(msg, "Class"); + if (!class_header) { + msg_err_session("missing Class header for multiclass learning"); + rspamd_controller_send_error(conn_ent, + 400, + "Class header is required for multiclass learning"); + return 0; + } + + task = rspamd_task_new(session->ctx->worker, session->cfg, session->pool, + session->ctx->lang_det, ctx->event_loop, FALSE); + + task->resolver = ctx->resolver; + task->s = rspamd_session_create(session->pool, + rspamd_controller_learn_fin_task, + NULL, + (event_finalizer_t) rspamd_task_free, + task); + task->fin_arg = conn_ent; + task->http_conn = rspamd_http_connection_ref(conn_ent->conn); + task->sock = -1; + session->task = task; + + cl_header = rspamd_http_message_find_header(msg, "classifier"); + if (cl_header) { + session->classifier = rspamd_mempool_ftokdup(session->pool, cl_header); + } + else { + session->classifier = NULL; + } + + if (!rspamd_task_load_message(task, msg, msg->body_buf.begin, msg->body_buf.len)) { + goto end; + } + + /* Set multiclass learning flag and store class name */ + class_name = rspamd_mempool_ftokdup(task->task_pool, class_header); + rspamd_task_set_autolearn_class(task, class_name); + + if (!rspamd_task_process(task, RSPAMD_TASK_PROCESS_LEARN)) { + msg_warn_session("<%s> message cannot be processed", + MESSAGE_FIELD_CHECK(task, message_id)); + goto end; + } + +end: + /* Set session spam flag for logging compatibility */ + if (class_name) { + session->is_spam = (strcmp(class_name, "spam") == 0); + } + else { + session->is_spam = FALSE; + } + rspamd_session_pending(task->s); + + return 0; +} + /* * Scan command handler: * request: /scan @@ -3292,7 +3386,7 @@ rspamd_controller_handle_unknown(struct rspamd_http_connection_entry *conn_ent, rspamd_http_message_add_header(rep, "Access-Control-Allow-Methods", "POST, GET, OPTIONS"); rspamd_http_message_add_header(rep, "Access-Control-Allow-Headers", - "Classifier,Content-Type,Password,Map,Weight,Flag,Hash"); + "Classifier,Class,Content-Type,Password,Map,Weight,Flag,Hash"); rspamd_http_connection_reset(conn_ent->conn); rspamd_http_router_insert_headers(conn_ent->rt, rep); rspamd_http_connection_write_message(conn_ent->conn, @@ -3456,7 +3550,7 @@ rspamd_controller_handle_lua_plugin(struct rspamd_http_connection_entry *conn_en */ static int rspamd_controller_handle_bayes_classifiers(struct rspamd_http_connection_entry *conn_ent, - struct rspamd_http_message *msg) + struct rspamd_http_message *msg) { struct rspamd_controller_session *session = conn_ent->ud; struct rspamd_controller_worker_ctx *ctx = session->ctx; @@ -4048,6 +4142,9 @@ start_controller_worker(struct rspamd_worker *worker) rspamd_http_router_add_path(ctx->http, PATH_LEARN_HAM, rspamd_controller_handle_learnham); + rspamd_http_router_add_path(ctx->http, + PATH_LEARN_CLASS, + rspamd_controller_handle_learnclass); rspamd_http_router_add_path(ctx->http, PATH_METRICS, rspamd_controller_handle_metrics); diff --git a/src/libserver/cfg_file.h b/src/libserver/cfg_file.h index cd2ab43141..5aaaece355 100644 --- a/src/libserver/cfg_file.h +++ b/src/libserver/cfg_file.h @@ -140,6 +140,7 @@ struct rspamd_statfile_config { char *label; /**< label of this statfile */ ucl_object_t *opts; /**< other options */ char *class_name; /**< class name for multi-class classification */ + unsigned int class_index; /**< class index for O(1) lookup during classification */ gboolean is_spam; /**< DEPRECATED: spam flag - use class_name instead */ struct rspamd_classifier_config *clcf; /**< parent pointer of classifier configuration */ gpointer data; /**< opaque data */ diff --git a/src/libserver/cfg_rcl.cxx b/src/libserver/cfg_rcl.cxx index 3f0a9606a2..5afb467452 100644 --- a/src/libserver/cfg_rcl.cxx +++ b/src/libserver/cfg_rcl.cxx @@ -1439,6 +1439,34 @@ rspamd_rcl_classifier_handler(rspamd_mempool_t *pool, cfg->classifiers = g_list_prepend(cfg->classifiers, ccf); + /* Populate class_names array from statfiles */ + if (ccf->statfiles) { + GList *cur = ccf->statfiles; + ccf->class_names = g_ptr_array_new(); + + while (cur) { + struct rspamd_statfile_config *stcf = (struct rspamd_statfile_config *) cur->data; + if (stcf->class_name) { + /* Check if class already exists */ + bool found = false; + for (unsigned int i = 0; i < ccf->class_names->len; i++) { + if (strcmp((char *) g_ptr_array_index(ccf->class_names, i), stcf->class_name) == 0) { + stcf->class_index = i; /* Store the index for O(1) lookup */ + found = true; + break; + } + } + + if (!found) { + /* Add new class */ + stcf->class_index = ccf->class_names->len; + g_ptr_array_add(ccf->class_names, g_strdup(stcf->class_name)); + } + } + cur = g_list_next(cur); + } + } + return TRUE; } diff --git a/src/libserver/task.c b/src/libserver/task.c index e043582846..f655ab11b2 100644 --- a/src/libserver/task.c +++ b/src/libserver/task.c @@ -942,15 +942,14 @@ rspamd_learn_task_spam(struct rspamd_task *task, const char *classifier, GError **err) { + /* Use unified class-based approach internally */ + const char *class_name = is_spam ? "spam" : "ham"; + /* Disable learn auto flag to avoid bad learn codes */ task->flags &= ~RSPAMD_TASK_FLAG_LEARN_AUTO; - if (is_spam) { - task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM; - } - else { - task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM; - } + /* Use the unified class-based learning approach */ + rspamd_task_set_autolearn_class(task, class_name); task->classifier = classifier; diff --git a/src/libstat/backends/redis_backend.cxx b/src/libstat/backends/redis_backend.cxx index 0c663123cb..c897afb9bc 100644 --- a/src/libstat/backends/redis_backend.cxx +++ b/src/libstat/backends/redis_backend.cxx @@ -22,6 +22,7 @@ #include "contrib/fmt/include/fmt/base.h" #include "libutil/cxx/error.hxx" +#include #include #include @@ -147,11 +148,17 @@ public: rspamd_token_t *tok; if (!results) { + msg_debug_bayes("process_tokens: no results available for statfile id=%d", id); return false; } + msg_debug_bayes("processing tokens for statfile id=%d, results size=%uz, class=%s", + id, results->size(), stcf->class_name ? stcf->class_name : "unknown"); + for (auto [idx, val]: *results) { tok = (rspamd_token_t *) g_ptr_array_index(tokens, idx - 1); + msg_debug_bayes("setting tok->values[%d] = %.2f for token idx %d (class=%s)", + id, val, idx, stcf->class_name ? stcf->class_name : "unknown"); tok->values[id] = val; } @@ -646,46 +653,62 @@ rspamd_redis_runtime(struct rspamd_task *task, /* No cached result (or learn), create new one */ auto *rt = new redis_stat_runtime(ctx, task, object_expanded); - if (!learn && stcf->clcf && stcf->clcf->class_names && stcf->clcf->class_names->len > 2) { - /* - * For multi-class classification, we need to create runtimes for ALL classes - * to avoid multiple Redis calls. The actual Redis call will fetch data for all classes. - */ + /* Find the statfile ID for the main runtime */ + int main_id = _id; /* Use the passed _id parameter */ + rt->id = main_id; + rt->stcf = stcf; + + /* For classification, create runtimes for all other statfiles to avoid multiple Redis calls */ + if (!learn && stcf->clcf && stcf->clcf->statfiles) { GList *cur = stcf->clcf->statfiles; + while (cur) { auto *other_stcf = (struct rspamd_statfile_config *) cur->data; - if (other_stcf != stcf) { - const char *other_label = get_class_label(other_stcf); - - auto maybe_other_rt = redis_stat_runtime::maybe_recover_from_mempool(task, - object_expanded, other_label); - if (!maybe_other_rt) { - auto *other_rt = new redis_stat_runtime(ctx, task, object_expanded); - other_rt->save_in_mempool(other_label); - other_rt->need_redis_call = false; + const char *other_label = get_class_label(other_stcf); + + /* Find the statfile ID by searching through all statfiles */ + struct rspamd_stat_ctx *st_ctx = rspamd_stat_get_ctx(); + int other_id = -1; + for (unsigned int i = 0; i < st_ctx->statfiles->len; i++) { + struct rspamd_statfile *st = (struct rspamd_statfile *) g_ptr_array_index(st_ctx->statfiles, i); + if (st->stcf == other_stcf) { + other_id = st->id; + msg_debug_bayes("found statfile mapping: %s (class=%s) → id=%d", + st->stcf->symbol, other_label, other_id); + break; } } + + if (other_id == -1) { + msg_debug_bayes("statfile not found for class %s, skipping", other_label); + /* Skip if statfile not found */ + cur = g_list_next(cur); + continue; + } + + if (other_stcf == stcf) { + /* This is the main statfile, use the main runtime */ + rt->save_in_mempool(other_label); + msg_debug_bayes("main runtime: statfile %s (class=%s) → id=%d", + stcf->symbol, other_label, rt->id); + } + else { + /* Create additional runtime for other statfile */ + auto *other_rt = new redis_stat_runtime(ctx, task, object_expanded); + other_rt->id = other_id; + other_rt->stcf = other_stcf; + other_rt->save_in_mempool(other_label); + msg_debug_bayes("additional runtime: statfile %s (class=%s) → id=%d", + other_stcf->symbol, other_label, other_id); + } + cur = g_list_next(cur); } } - else if (!learn) { - /* - * For binary classification, create the opposite class runtime to avoid - * double call for Redis scripts (backward compatibility). - */ - const char *opposite_label = stcf->is_spam ? "H" : "S"; - auto maybe_opposite_rt = redis_stat_runtime::maybe_recover_from_mempool(task, - object_expanded, opposite_label); - - if (!maybe_opposite_rt) { - auto *opposite_rt = new redis_stat_runtime(ctx, task, object_expanded); - opposite_rt->save_in_mempool(opposite_label); - opposite_rt->need_redis_call = false; - } + else { + rt->save_in_mempool(class_label); } - rt->save_in_mempool(class_label); - return rt; } @@ -859,7 +882,6 @@ rspamd_redis_classified(lua_State *L) if (rt == nullptr) { msg_err_task("internal error: cannot find runtime for cookie %s", cookie); - return 0; } @@ -874,135 +896,131 @@ rspamd_redis_classified(lua_State *L) return 0; } - /* Check the array length to determine format: - * - Length 4: binary format [learned_ham, learned_spam, ham_tokens, spam_tokens] - * - Length 2: multi-class format [learned_counts_table, outputs_table] - */ + /* Redis returns [learned_counts_array, token_results_array] + * Both ordered the same way as statfiles in classifier */ size_t result_len = rspamd_lua_table_size(L, 3); msg_debug_bayes("Redis result array length: %uz", result_len); - auto filler_func = [](redis_stat_runtime *rt, lua_State *L, unsigned learned, int tokens_pos) { - rt->learned = learned; - redis_stat_runtime::result_type *res; + if (result_len != 2) { + msg_err_task("internal error: expected 2-element result from Redis script, got %uz", result_len); + rt->err = rspamd::util::error("invalid Redis script result format", 500); + return 0; + } - res = new redis_stat_runtime::result_type(); + /* Get learned_counts_array and token_results_array */ + lua_rawgeti(L, 3, 1); /* learned_counts -> position 4 */ + lua_rawgeti(L, 3, 2); /* token_results -> position 5 */ - for (lua_pushnil(L); lua_next(L, tokens_pos); lua_pop(L, 1)) { - lua_rawgeti(L, -1, 1); - auto idx = lua_tointeger(L, -1); - lua_pop(L, 1); + /* Process results for all statfiles in order using class_index (O(N) instead of O(N²)) */ + if (rt->stcf->clcf && rt->stcf->clcf->statfiles) { + GList *cur = rt->stcf->clcf->statfiles; + int redis_idx = 1; /* Redis result array index (1-based) */ - lua_rawgeti(L, -1, 2); - auto value = lua_tonumber(L, -1); - lua_pop(L, 1); + while (cur) { + auto *stcf = (struct rspamd_statfile_config *) cur->data; - res->emplace_back(idx, value); - } + /* Direct statfile lookup using global statfiles array */ + struct rspamd_stat_ctx *st_ctx = rspamd_stat_get_ctx(); + struct rspamd_statfile *st = nullptr; + + /* Find statfile by config pointer (still O(N) but unavoidable) */ + for (unsigned int i = 0; i < st_ctx->statfiles->len; i++) { + struct rspamd_statfile *candidate = (struct rspamd_statfile *) g_ptr_array_index(st_ctx->statfiles, i); + if (candidate->stcf == stcf) { + st = candidate; + break; + } + } + + if (!st) { + msg_debug_bayes("statfile not found for config %s, skipping", stcf->symbol); + cur = g_list_next(cur); + redis_idx++; + continue; + } - rt->set_results(res); - }; + /* Get or create runtime for this statfile */ + auto *statfile_rt = rt; /* Use current runtime for first statfile */ + if (stcf != rt->stcf) { + const char *class_label = get_class_label(stcf); + auto maybe_rt = redis_stat_runtime::maybe_recover_from_mempool(task, + rt->redis_object_expanded, + class_label); + if (maybe_rt) { + statfile_rt = maybe_rt.value(); + } + else { + msg_debug_bayes("runtime not found for class %s, skipping", class_label); + cur = g_list_next(cur); + redis_idx++; + continue; + } + } - bool is_binary_format = (result_len == 4); + /* Ensure correct statfile ID assignment */ + statfile_rt->id = st->id; + + /* Process token results for this statfile (Redis array index redis_idx) */ + lua_rawgeti(L, 5, redis_idx); /* Get token_results[redis_idx] */ + if (lua_istable(L, -1)) { + /* Parse token results into statfile runtime */ + auto *res = new std::vector>(); + + lua_pushnil(L); /* First key for iteration */ + while (lua_next(L, -2) != 0) { + if (lua_istable(L, -1) && lua_objlen(L, -1) == 2) { + lua_rawgeti(L, -1, 1); /* token_index */ + lua_rawgeti(L, -2, 2); /* token_count */ + + if (lua_isnumber(L, -2) && lua_isnumber(L, -1)) { + int token_idx = lua_tointeger(L, -2); + float token_count = lua_tonumber(L, -1); + res->emplace_back(token_idx, token_count); + } + + lua_pop(L, 2); /* Pop token_index and token_count */ + } + lua_pop(L, 1); /* Pop value, keep key for next iteration */ + } - if (is_binary_format) { - /* Binary format: [learned_ham, learned_spam, ham_tokens, spam_tokens] */ + statfile_rt->set_results(res); + } + lua_pop(L, 1); /* Pop token_results[redis_idx] */ - /* Find the opposite runtime for binary classification compatibility */ - const char *opposite_label; - if (rt->stcf->class_name) { - /* Multi-class: find a different class (simplified for now) */ - opposite_label = strcmp(get_class_label(rt->stcf), "S") == 0 ? "H" : "S"; + cur = g_list_next(cur); + redis_idx++; } - else { - /* Binary: use opposite spam/ham */ - opposite_label = rt->stcf->is_spam ? "H" : "S"; - } - auto opposite_rt_maybe = redis_stat_runtime::maybe_recover_from_mempool(task, - rt->redis_object_expanded, - opposite_label); + } - if (!opposite_rt_maybe) { - msg_err_task("internal error: cannot find opposite runtime for cookie %s", cookie); - return 0; - } + /* Clean up stack */ + lua_pop(L, 2); /* Pop learned_counts and token_results */ - /* Extract values from the result table at position 3 */ - lua_rawgeti(L, 3, 1); /* learned_ham -> position 4 */ - lua_rawgeti(L, 3, 2); /* learned_spam -> position 5 */ - lua_rawgeti(L, 3, 3); /* ham_tokens -> position 6 */ - lua_rawgeti(L, 3, 4); /* spam_tokens -> position 7 */ + /* Process tokens for all runtimes */ + g_assert(rt->tokens != nullptr); - unsigned learned_ham = lua_tointeger(L, 4); - unsigned learned_spam = lua_tointeger(L, 5); + /* Process tokens for all statfiles */ + if (rt->stcf->clcf && rt->stcf->clcf->statfiles) { + GList *cur = rt->stcf->clcf->statfiles; - if (rt->stcf->is_spam || (rt->stcf->class_name && strcmp(get_class_label(rt->stcf), "S") == 0)) { - /* Current runtime is spam, use spam data */ - filler_func(rt, L, learned_spam, 7); /* spam_tokens at position 7 */ - filler_func(opposite_rt_maybe.value(), L, learned_ham, 6); /* ham_tokens at position 6 */ - } - else { - /* Current runtime is ham, use ham data */ - filler_func(rt, L, learned_ham, 6); /* ham_tokens at position 6 */ - filler_func(opposite_rt_maybe.value(), L, learned_spam, 7); /* spam_tokens at position 7 */ - } + while (cur) { + auto *stcf = (struct rspamd_statfile_config *) cur->data; + const char *class_label = get_class_label(stcf); - /* Clean up the stack - pop the 4 extracted values */ - lua_pop(L, 4); + auto maybe_rt = redis_stat_runtime::maybe_recover_from_mempool(task, + rt->redis_object_expanded, + class_label); + if (maybe_rt) { + auto *statfile_rt = maybe_rt.value(); + statfile_rt->process_tokens(rt->tokens); + } - /* Process all tokens */ - g_assert(rt->tokens != nullptr); - rt->process_tokens(rt->tokens); - opposite_rt_maybe.value()->process_tokens(rt->tokens); + cur = g_list_next(cur); + } } else { - /* Multi-class format: [learned_counts_table, outputs_table] */ - - /* Get learned counts table (index 1) and outputs table (index 2) */ - lua_rawgeti(L, 3, 1); /* learned_counts */ - lua_rawgeti(L, 3, 2); /* outputs */ - - /* Iterate through all class labels to fill all runtimes */ - if (rt->stcf->clcf && rt->stcf->clcf->class_labels) { - GHashTableIter iter; - gpointer key, value; - g_hash_table_iter_init(&iter, rt->stcf->clcf->class_labels); - - while (g_hash_table_iter_next(&iter, &key, &value)) { - const char *class_label = (const char *) value; - - /* Find runtime for this class */ - auto class_rt_maybe = redis_stat_runtime::maybe_recover_from_mempool(task, - rt->redis_object_expanded, - class_label); - - if (class_rt_maybe) { - auto *class_rt = class_rt_maybe.value(); - - /* Get learned count for this class */ - lua_pushstring(L, class_label); - lua_gettable(L, -3); /* learned_counts[class_label] */ - unsigned learned = lua_tointeger(L, -1); - lua_pop(L, 1); - - /* Get outputs for this class */ - lua_pushstring(L, class_label); - lua_gettable(L, -2); /* outputs[class_label] */ - int outputs_pos = lua_gettop(L); - - filler_func(class_rt, L, learned, outputs_pos); - lua_pop(L, 1); - } - } - } - - lua_pop(L, 2); /* Pop learned_counts and outputs tables */ - - /* Process tokens for all runtimes */ - g_assert(rt->tokens != nullptr); + /* Fallback: just process the main runtime */ rt->process_tokens(rt->tokens); } - - /* Tokens processed - no need to set flags in multi-class approach */ } else { /* Error message is on index 3 */ diff --git a/src/libstat/classifiers/bayes.c b/src/libstat/classifiers/bayes.c index 836c27c88d..4d070ee202 100644 --- a/src/libstat/classifiers/bayes.c +++ b/src/libstat/classifiers/bayes.c @@ -302,15 +302,16 @@ bayes_classify_token_multiclass(struct rspamd_classifier *ctx, val = tok->values[id]; if (val > 0) { - /* Find which class this statfile belongs to */ - for (j = 0; j < cl->num_classes; j++) { - if (st->stcf->class_name && - strcmp(st->stcf->class_name, cl->class_names[j]) == 0) { - class_counts[j] += val; - total_count += val; - cl->total_hits += val; - break; - } + /* Direct O(1) class index lookup instead of O(N) string comparison */ + if (st->stcf->class_name && st->stcf->class_index < cl->num_classes) { + unsigned int class_idx = st->stcf->class_index; + class_counts[class_idx] += val; + total_count += val; + cl->total_hits += val; + } + else { + msg_debug_bayes("invalid class_index %ud >= %ud for statfile %s", + st->stcf->class_index, cl->num_classes, st->stcf->symbol); } } } @@ -348,7 +349,7 @@ bayes_classify_token_multiclass(struct rspamd_classifier *ctx, if (tok->t1 && tok->t2) { msg_debug_bayes("token(%s) %uL <%*s:%*s>: weight: %.3f, total_count: %ud, " - "processed for %u classes", + "processed for %ud classes", token_type, tok->data, (int) tok->t1->stemmed.len, tok->t1->stemmed.begin, (int) tok->t2->stemmed.len, tok->t2->stemmed.begin, @@ -385,13 +386,27 @@ bayes_classify_multiclass(struct rspamd_classifier *ctx, cl.cfg = ctx->cfg; /* Get class information from classifier config */ - if (!ctx->cfg->class_names || ctx->cfg->class_names->len < 2) { - msg_debug_bayes("insufficient classes for multiclass classification"); + if (!ctx->cfg->class_names) { + msg_debug_bayes("no class_names array in classifier config"); + return TRUE; /* Fall back to binary mode */ + } + if (ctx->cfg->class_names->len < 2) { + msg_debug_bayes("insufficient classes: %ud < 2", (unsigned int) ctx->cfg->class_names->len); + return TRUE; /* Fall back to binary mode */ + } + if (!ctx->cfg->class_names->pdata) { + msg_debug_bayes("class_names->pdata is NULL"); return TRUE; /* Fall back to binary mode */ } cl.num_classes = ctx->cfg->class_names->len; cl.class_names = (char **) ctx->cfg->class_names->pdata; + + /* Debug: verify class names are accessible */ + msg_debug_bayes("multiclass setup: ctx->cfg->class_names=%p, len=%ud, pdata=%p", + ctx->cfg->class_names, (unsigned int) ctx->cfg->class_names->len, ctx->cfg->class_names->pdata); + msg_debug_bayes("multiclass setup: cl.num_classes=%ud, cl.class_names=%p", + cl.num_classes, cl.class_names); cl.class_log_probs = g_alloca(cl.num_classes * sizeof(double)); cl.class_learns = g_alloca(cl.num_classes * sizeof(uint64_t)); @@ -459,6 +474,22 @@ bayes_classify_multiclass(struct rspamd_classifier *ctx, } if (cl.processed_tokens == 0) { + /* Debug: check why no tokens were processed */ + msg_debug_bayes("examining token values for debugging:"); + for (i = 0; i < MIN(tokens->len, 10); i++) { /* Check first 10 tokens */ + tok = g_ptr_array_index(tokens, i); + for (j = 0; j < ctx->statfiles_ids->len; j++) { + id = g_array_index(ctx->statfiles_ids, int, j); + if (tok->values[id] > 0) { + struct rspamd_statfile *st = g_ptr_array_index(ctx->ctx->statfiles, id); + msg_debug_bayes("token %ud: values[%d] = %.2f (class=%s, symbol=%s)", + i, id, tok->values[id], + st->stcf->class_name ? st->stcf->class_name : "unknown", + st->stcf->symbol); + } + } + } + msg_info_bayes("no tokens found in bayes database " "(%ud total tokens, %ud text tokens), ignore stats", tokens->len, text_tokens);