]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Project] Further updates
authorVsevolod Stakhov <vsevolod@rspamd.com>
Tue, 22 Jul 2025 21:30:03 +0000 (22:30 +0100)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Tue, 22 Jul 2025 21:30:03 +0000 (22:30 +0100)
lualib/redis_scripts/bayes_classify.lua
src/client/rspamc.cxx
src/controller.c
src/libserver/cfg_file.h
src/libserver/cfg_rcl.cxx
src/libserver/task.c
src/libstat/backends/redis_backend.cxx
src/libstat/classifiers/bayes.c

index e07b9a795690a3ddfcc60405b636da0d3003e407..923adcc5ad69c83b62f81fc044cfe98cd27bf2e4 100644 (file)
@@ -1,51 +1,47 @@
 -- Lua script to perform bayes classification (multi-class)
 -- This script accepts the following parameters:
 -- key1 - prefix for bayes tokens (e.g. for per-user classification)
--- key2 - class labels: either table of all class labels (multi-class) or single string (binary)
+-- key2 - class labels: table of all class labels as "TABLE:label1,label2,..."
 -- key3 - set of tokens encoded in messagepack array of strings
 
 local prefix = KEYS[1]
 local class_labels_arg = KEYS[2]
 local input_tokens = cmsgpack.unpack(KEYS[3])
 
--- Determine if this is multi-class (table) or binary (string)
+-- Parse class labels (always expect TABLE: format)
 local class_labels = {}
-
--- Check if this is a table serialized as "TABLE:label1,label2,..."
 if string.match(class_labels_arg, "^TABLE:") then
   local labels_str = string.sub(class_labels_arg, 7) -- Remove "TABLE:" prefix
-  -- Split by comma
   for label in string.gmatch(labels_str, "([^,]+)") do
     table.insert(class_labels, label)
   end
 else
-  -- Binary compatibility: handle old boolean or single string format
-  if class_labels_arg == "true" then
-    class_labels = { "S" }              -- spam
-  elseif class_labels_arg == "false" then
-    class_labels = { "H" }              -- ham
-  else
-    class_labels = { class_labels_arg } -- single class label
-  end
+  -- Legacy single class - convert to array
+  class_labels = { class_labels_arg }
 end
 
--- Get learned counts for all classes
+-- Get learned counts for all classes (ordered)
 local learned_counts = {}
 for _, label in ipairs(class_labels) do
   local key = 'learns_' .. string.lower(label)
-  -- Also try legacy keys for backward compatibility
+  -- Handle legacy keys for backward compatibility
   if label == 'H' then
     key = 'learns_ham'
   elseif label == 'S' then
     key = 'learns_spam'
   end
-  learned_counts[label] = tonumber(redis.call('HGET', prefix, key)) or 0
+  table.insert(learned_counts, tonumber(redis.call('HGET', prefix, key)) or 0)
+end
+
+-- Get token data for all classes (ordered)
+local token_results = {}
+for i, label in ipairs(class_labels) do
+  token_results[i] = {}
 end
 
--- Get token data for all classes (only if we have learns for any class)
-local outputs = {}
+-- Check if we have any learning data
 local has_learns = false
-for _, count in pairs(learned_counts) do
+for _, count in ipairs(learned_counts) do
   if count > 0 then
     has_learns = true
     break
@@ -53,11 +49,6 @@ for _, count in pairs(learned_counts) do
 end
 
 if has_learns then
-  -- Initialize outputs for each class
-  for _, label in ipairs(class_labels) do
-    outputs[label] = {}
-  end
-
   -- Process each token
   for i, token in ipairs(input_tokens) do
     local token_data = redis.call('HMGET', token, unpack(class_labels))
@@ -65,32 +56,13 @@ if has_learns then
     if token_data then
       for j, label in ipairs(class_labels) do
         local count = token_data[j]
-        if count then
-          table.insert(outputs[label], { i, tonumber(count) })
+        if count and tonumber(count) > 0 then
+          table.insert(token_results[j], { i, tonumber(count) })
         end
       end
     end
   end
 end
 
--- Format output for backward compatibility
-if #class_labels == 2 and class_labels[1] == 'H' and class_labels[2] == 'S' then
-  -- Binary format: [learned_ham, learned_spam, output_ham, output_spam]
-  return {
-    learned_counts['H'] or 0,
-    learned_counts['S'] or 0,
-    outputs['H'] or {},
-    outputs['S'] or {}
-  }
-elseif #class_labels == 2 and class_labels[1] == 'S' and class_labels[2] == 'H' then
-  -- Binary format: [learned_ham, learned_spam, output_ham, output_spam]
-  return {
-    learned_counts['H'] or 0,
-    learned_counts['S'] or 0,
-    outputs['H'] or {},
-    outputs['S'] or {}
-  }
-else
-  -- Multi-class format: [learned_counts_table, outputs_table]
-  return { learned_counts, outputs }
-end
+-- Always return ordered arrays: [learned_counts_array, token_results_array]
+return { learned_counts, token_results }
index 40435987735c3b7aa4aa185427263ce2e5ca9090..af88acb3377cd7f1630dc0ca82840da52805b3e2 100644 (file)
@@ -59,6 +59,7 @@ static const char *user = nullptr;
 static const char *helo = nullptr;
 static const char *hostname = nullptr;
 static const char *classifier = nullptr;
+static const char *learn_class_name = nullptr;
 static const char *local_addr = nullptr;
 static const char *execute = nullptr;
 static const char *sort = nullptr;
@@ -198,6 +199,7 @@ enum rspamc_command_type {
        RSPAMC_COMMAND_SYMBOLS,
        RSPAMC_COMMAND_LEARN_SPAM,
        RSPAMC_COMMAND_LEARN_HAM,
+       RSPAMC_COMMAND_LEARN_CLASS,
        RSPAMC_COMMAND_FUZZY_ADD,
        RSPAMC_COMMAND_FUZZY_DEL,
        RSPAMC_COMMAND_FUZZY_DELHASH,
@@ -249,6 +251,15 @@ static const constexpr auto rspamc_commands = rspamd::array_of(
                .is_privileged = TRUE,
                .need_input = TRUE,
                .command_output_func = nullptr},
+       rspamc_command{
+               .cmd = RSPAMC_COMMAND_LEARN_CLASS,
+               .name = "learn_class",
+               .path = "learnclass",
+               .description = "learn message as class",
+               .is_controller = TRUE,
+               .is_privileged = TRUE,
+               .need_input = TRUE,
+               .command_output_func = nullptr},
        rspamc_command{
                .cmd = RSPAMC_COMMAND_FUZZY_ADD,
                .name = "fuzzy_add",
@@ -527,8 +538,7 @@ rspamc_password_callback(const char *option_name,
                                auto *map = (char *) locked_mmap.value().get_map();
                                value_view = std::string_view{map, locked_mmap->get_size()};
                                auto right = value_view.end() - 1;
-                               for (; right > value_view.cbegin() && g_ascii_isspace(*right); --right)
-                                       ;
+                               for (; right > value_view.cbegin() && g_ascii_isspace(*right); --right);
                                std::string_view str{value_view.begin(), static_cast<size_t>(right - value_view.begin()) + 1};
                                processed_passwd.assign(std::begin(str), std::end(str));
                                processed_passwd.push_back('\0'); /* Null-terminate for C part */
@@ -649,6 +659,7 @@ check_rspamc_command(const char *cmd) -> std::optional<rspamc_command>
                {"report", RSPAMC_COMMAND_SYMBOLS},
                {"learn_spam", RSPAMC_COMMAND_LEARN_SPAM},
                {"learn_ham", RSPAMC_COMMAND_LEARN_HAM},
+               {"learn_class", RSPAMC_COMMAND_LEARN_CLASS},
                {"fuzzy_add", RSPAMC_COMMAND_FUZZY_ADD},
                {"fuzzy_del", RSPAMC_COMMAND_FUZZY_DEL},
                {"fuzzy_delhash", RSPAMC_COMMAND_FUZZY_DELHASH},
@@ -659,10 +670,33 @@ check_rspamc_command(const char *cmd) -> std::optional<rspamc_command>
        });
 
        std::string cmd_lc = rspamd_string_tolower(cmd);
+
+       // Handle learn_class:classname syntax
+       if (cmd_lc.find("learn_class:") == 0) {
+               auto colon_pos = cmd_lc.find(':');
+               if (colon_pos != std::string::npos && colon_pos + 1 < cmd_lc.length()) {
+                       auto class_name = cmd_lc.substr(colon_pos + 1);
+                       // Store class name globally for later use
+                       learn_class_name = g_strdup(class_name.c_str());
+                       // Return the learn_class command
+                       auto elt_it = std::find_if(rspamc_commands.begin(), rspamc_commands.end(), [&](const auto &item) {
+                               return item.cmd == RSPAMC_COMMAND_LEARN_CLASS;
+                       });
+                       if (elt_it != std::end(rspamc_commands)) {
+                               return *elt_it;
+                       }
+               }
+               return std::nullopt;
+       }
+
        auto ct = rspamd::find_map(str_map, std::string_view{cmd_lc});
 
+       if (!ct.has_value()) {
+               return std::nullopt;
+       }
+
        auto elt_it = std::find_if(rspamc_commands.begin(), rspamc_commands.end(), [&](const auto &item) {
-               return item.cmd == ct;
+               return item.cmd == ct.value();
        });
 
        if (elt_it != std::end(rspamc_commands)) {
@@ -799,6 +833,10 @@ add_options(GQueue *opts)
                add_client_header(opts, "Classifier", classifier);
        }
 
+       if (learn_class_name) {
+               add_client_header(opts, "Class", learn_class_name);
+       }
+
        if (weight != 0) {
                auto nstr = fmt::format("{}", weight);
                add_client_header(opts, "Weight", nstr.c_str());
@@ -1918,7 +1956,7 @@ rspamc_client_cb(struct rspamd_client_connection *conn,
 
                                        if (raw_body) {
                                                /* We can also output the resulting json */
-                                               rspamc_print(out, "{}\n", std::string_view{raw_body, (std::size_t)(rawlen - bodylen)});
+                                               rspamc_print(out, "{}\n", std::string_view{raw_body, (std::size_t) (rawlen - bodylen)});
                                        }
                                }
                        }
@@ -1950,7 +1988,7 @@ rspamc_process_input(struct ev_loop *ev_base, const struct rspamc_command &cmd,
                p = strrchr(connect_str, ']');
 
                if (p != nullptr) {
-                       hostbuf.assign(connect_str + 1, (std::size_t)(p - connect_str - 1));
+                       hostbuf.assign(connect_str + 1, (std::size_t) (p - connect_str - 1));
                        p++;
                }
                else {
@@ -1965,7 +2003,7 @@ rspamc_process_input(struct ev_loop *ev_base, const struct rspamc_command &cmd,
 
        if (hostbuf.empty()) {
                if (p != nullptr) {
-                       hostbuf.assign(connect_str, (std::size_t)(p - connect_str));
+                       hostbuf.assign(connect_str, (std::size_t) (p - connect_str));
                }
                else {
                        hostbuf.assign(connect_str);
index 0550ba6b866f201fef82a1468b81ca14e7d082b1..6e0e4cac1eebc23d65dafe2f32b17d38cf90f03a 100644 (file)
@@ -53,6 +53,7 @@
 #define PATH_HISTORY_RESET "/historyreset"
 #define PATH_LEARN_SPAM "/learnspam"
 #define PATH_LEARN_HAM "/learnham"
+#define PATH_LEARN_CLASS "/learnclass"
 #define PATH_METRICS "/metrics"
 #define PATH_READY "/ready"
 #define PATH_SAVE_ACTIONS "/saveactions"
@@ -2126,6 +2127,7 @@ rspamd_controller_handle_learn_common(
        struct rspamd_controller_worker_ctx *ctx;
        struct rspamd_task *task;
        const rspamd_ftok_t *cl_header;
+       const char *class_name;
 
        ctx = session->ctx;
 
@@ -2167,7 +2169,9 @@ rspamd_controller_handle_learn_common(
                goto end;
        }
 
-       rspamd_learn_task_spam(task, is_spam, session->classifier, NULL);
+       /* Use unified class-based learning approach */
+       class_name = is_spam ? "spam" : "ham";
+       rspamd_task_set_autolearn_class(task, class_name);
 
        if (!rspamd_task_process(task, RSPAMD_TASK_PROCESS_LEARN)) {
                msg_warn_session("<%s> message cannot be processed",
@@ -2211,6 +2215,96 @@ rspamd_controller_handle_learnham(
        return rspamd_controller_handle_learn_common(conn_ent, msg, FALSE);
 }
 
+/*
+ * Learn class command handler:
+ * request: /learnclass
+ * headers: Password, Class
+ * input: plaintext data
+ * reply: json {"success":true} or {"error":"error message"}
+ */
+static int
+rspamd_controller_handle_learnclass(
+       struct rspamd_http_connection_entry *conn_ent,
+       struct rspamd_http_message *msg)
+{
+       struct rspamd_controller_session *session = conn_ent->ud;
+       struct rspamd_controller_worker_ctx *ctx;
+       struct rspamd_task *task;
+       const rspamd_ftok_t *cl_header, *class_header;
+       char *class_name = NULL;
+
+       ctx = session->ctx;
+
+       if (!rspamd_controller_check_password(conn_ent, session, msg, TRUE)) {
+               return 0;
+       }
+
+       if (rspamd_http_message_get_body(msg, NULL) == NULL) {
+               msg_err_session("got zero length body, cannot continue");
+               rspamd_controller_send_error(conn_ent,
+                                                                        400,
+                                                                        "Empty body is not permitted");
+               return 0;
+       }
+
+       class_header = rspamd_http_message_find_header(msg, "Class");
+       if (!class_header) {
+               msg_err_session("missing Class header for multiclass learning");
+               rspamd_controller_send_error(conn_ent,
+                                                                        400,
+                                                                        "Class header is required for multiclass learning");
+               return 0;
+       }
+
+       task = rspamd_task_new(session->ctx->worker, session->cfg, session->pool,
+                                                  session->ctx->lang_det, ctx->event_loop, FALSE);
+
+       task->resolver = ctx->resolver;
+       task->s = rspamd_session_create(session->pool,
+                                                                       rspamd_controller_learn_fin_task,
+                                                                       NULL,
+                                                                       (event_finalizer_t) rspamd_task_free,
+                                                                       task);
+       task->fin_arg = conn_ent;
+       task->http_conn = rspamd_http_connection_ref(conn_ent->conn);
+       task->sock = -1;
+       session->task = task;
+
+       cl_header = rspamd_http_message_find_header(msg, "classifier");
+       if (cl_header) {
+               session->classifier = rspamd_mempool_ftokdup(session->pool, cl_header);
+       }
+       else {
+               session->classifier = NULL;
+       }
+
+       if (!rspamd_task_load_message(task, msg, msg->body_buf.begin, msg->body_buf.len)) {
+               goto end;
+       }
+
+       /* Set multiclass learning flag and store class name */
+       class_name = rspamd_mempool_ftokdup(task->task_pool, class_header);
+       rspamd_task_set_autolearn_class(task, class_name);
+
+       if (!rspamd_task_process(task, RSPAMD_TASK_PROCESS_LEARN)) {
+               msg_warn_session("<%s> message cannot be processed",
+                                                MESSAGE_FIELD_CHECK(task, message_id));
+               goto end;
+       }
+
+end:
+       /* Set session spam flag for logging compatibility */
+       if (class_name) {
+               session->is_spam = (strcmp(class_name, "spam") == 0);
+       }
+       else {
+               session->is_spam = FALSE;
+       }
+       rspamd_session_pending(task->s);
+
+       return 0;
+}
+
 /*
  * Scan command handler:
  * request: /scan
@@ -3292,7 +3386,7 @@ rspamd_controller_handle_unknown(struct rspamd_http_connection_entry *conn_ent,
                rspamd_http_message_add_header(rep, "Access-Control-Allow-Methods",
                                                                           "POST, GET, OPTIONS");
                rspamd_http_message_add_header(rep, "Access-Control-Allow-Headers",
-                                                                          "Classifier,Content-Type,Password,Map,Weight,Flag,Hash");
+                                                                          "Classifier,Class,Content-Type,Password,Map,Weight,Flag,Hash");
                rspamd_http_connection_reset(conn_ent->conn);
                rspamd_http_router_insert_headers(conn_ent->rt, rep);
                rspamd_http_connection_write_message(conn_ent->conn,
@@ -3456,7 +3550,7 @@ rspamd_controller_handle_lua_plugin(struct rspamd_http_connection_entry *conn_en
  */
 static int
 rspamd_controller_handle_bayes_classifiers(struct rspamd_http_connection_entry *conn_ent,
-                                                                                       struct rspamd_http_message *msg)
+                                                                                  struct rspamd_http_message *msg)
 {
        struct rspamd_controller_session *session = conn_ent->ud;
        struct rspamd_controller_worker_ctx *ctx = session->ctx;
@@ -4048,6 +4142,9 @@ start_controller_worker(struct rspamd_worker *worker)
        rspamd_http_router_add_path(ctx->http,
                                                                PATH_LEARN_HAM,
                                                                rspamd_controller_handle_learnham);
+       rspamd_http_router_add_path(ctx->http,
+                                                               PATH_LEARN_CLASS,
+                                                               rspamd_controller_handle_learnclass);
        rspamd_http_router_add_path(ctx->http,
                                                                PATH_METRICS,
                                                                rspamd_controller_handle_metrics);
index cd2ab43141061c037ddec08a91cb7a6f4f2f5b59..5aaaece3552555dd4dcb9933b26c43b6dcd56e06 100644 (file)
@@ -140,6 +140,7 @@ struct rspamd_statfile_config {
        char *label;                           /**< label of this statfile                                                              */
        ucl_object_t *opts;                    /**< other options                                                                               */
        char *class_name;                      /**< class name for multi-class classification                   */
+       unsigned int class_index;              /**< class index for O(1) lookup during classification   */
        gboolean is_spam;                      /**< DEPRECATED: spam flag - use class_name instead              */
        struct rspamd_classifier_config *clcf; /**< parent pointer of classifier configuration                  */
        gpointer data;                         /**< opaque data                                                                                 */
index 3f0a9606a24ef30a2b1a9b2223e9418f40388ae2..5afb467452bccd6675ba46cfaa5ecc900f1df42b 100644 (file)
@@ -1439,6 +1439,34 @@ rspamd_rcl_classifier_handler(rspamd_mempool_t *pool,
 
        cfg->classifiers = g_list_prepend(cfg->classifiers, ccf);
 
+       /* Populate class_names array from statfiles */
+       if (ccf->statfiles) {
+               GList *cur = ccf->statfiles;
+               ccf->class_names = g_ptr_array_new();
+
+               while (cur) {
+                       struct rspamd_statfile_config *stcf = (struct rspamd_statfile_config *) cur->data;
+                       if (stcf->class_name) {
+                               /* Check if class already exists */
+                               bool found = false;
+                               for (unsigned int i = 0; i < ccf->class_names->len; i++) {
+                                       if (strcmp((char *) g_ptr_array_index(ccf->class_names, i), stcf->class_name) == 0) {
+                                               stcf->class_index = i; /* Store the index for O(1) lookup */
+                                               found = true;
+                                               break;
+                                       }
+                               }
+
+                               if (!found) {
+                                       /* Add new class */
+                                       stcf->class_index = ccf->class_names->len;
+                                       g_ptr_array_add(ccf->class_names, g_strdup(stcf->class_name));
+                               }
+                       }
+                       cur = g_list_next(cur);
+               }
+       }
+
        return TRUE;
 }
 
index e0435828461a09f58da70a56f7f735ffc3558b61..f655ab11b2c11a13961aa0945bed912a2e823dbe 100644 (file)
@@ -942,15 +942,14 @@ rspamd_learn_task_spam(struct rspamd_task *task,
                                           const char *classifier,
                                           GError **err)
 {
+       /* Use unified class-based approach internally */
+       const char *class_name = is_spam ? "spam" : "ham";
+
        /* Disable learn auto flag to avoid bad learn codes */
        task->flags &= ~RSPAMD_TASK_FLAG_LEARN_AUTO;
 
-       if (is_spam) {
-               task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
-       }
-       else {
-               task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
-       }
+       /* Use the unified class-based learning approach */
+       rspamd_task_set_autolearn_class(task, class_name);
 
        task->classifier = classifier;
 
index 0c663123cbc4d3c1a6c5d1d32436eb2d23781bca..c897afb9bcd17604ea2f67e425d345970245db18 100644 (file)
@@ -22,6 +22,7 @@
 #include "contrib/fmt/include/fmt/base.h"
 
 #include "libutil/cxx/error.hxx"
+#include <map>
 
 #include <string>
 #include <cstdint>
@@ -147,11 +148,17 @@ public:
                rspamd_token_t *tok;
 
                if (!results) {
+                       msg_debug_bayes("process_tokens: no results available for statfile id=%d", id);
                        return false;
                }
 
+               msg_debug_bayes("processing tokens for statfile id=%d, results size=%uz, class=%s",
+                                               id, results->size(), stcf->class_name ? stcf->class_name : "unknown");
+
                for (auto [idx, val]: *results) {
                        tok = (rspamd_token_t *) g_ptr_array_index(tokens, idx - 1);
+                       msg_debug_bayes("setting tok->values[%d] = %.2f for token idx %d (class=%s)",
+                                                       id, val, idx, stcf->class_name ? stcf->class_name : "unknown");
                        tok->values[id] = val;
                }
 
@@ -646,46 +653,62 @@ rspamd_redis_runtime(struct rspamd_task *task,
        /* No cached result (or learn), create new one */
        auto *rt = new redis_stat_runtime<float>(ctx, task, object_expanded);
 
-       if (!learn && stcf->clcf && stcf->clcf->class_names && stcf->clcf->class_names->len > 2) {
-               /*
-                * For multi-class classification, we need to create runtimes for ALL classes
-                * to avoid multiple Redis calls. The actual Redis call will fetch data for all classes.
-                */
+       /* Find the statfile ID for the main runtime */
+       int main_id = _id; /* Use the passed _id parameter */
+       rt->id = main_id;
+       rt->stcf = stcf;
+
+       /* For classification, create runtimes for all other statfiles to avoid multiple Redis calls */
+       if (!learn && stcf->clcf && stcf->clcf->statfiles) {
                GList *cur = stcf->clcf->statfiles;
+
                while (cur) {
                        auto *other_stcf = (struct rspamd_statfile_config *) cur->data;
-                       if (other_stcf != stcf) {
-                               const char *other_label = get_class_label(other_stcf);
-
-                               auto maybe_other_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
-                                                                                                                                                                                       object_expanded, other_label);
-                               if (!maybe_other_rt) {
-                                       auto *other_rt = new redis_stat_runtime<float>(ctx, task, object_expanded);
-                                       other_rt->save_in_mempool(other_label);
-                                       other_rt->need_redis_call = false;
+                       const char *other_label = get_class_label(other_stcf);
+
+                       /* Find the statfile ID by searching through all statfiles */
+                       struct rspamd_stat_ctx *st_ctx = rspamd_stat_get_ctx();
+                       int other_id = -1;
+                       for (unsigned int i = 0; i < st_ctx->statfiles->len; i++) {
+                               struct rspamd_statfile *st = (struct rspamd_statfile *) g_ptr_array_index(st_ctx->statfiles, i);
+                               if (st->stcf == other_stcf) {
+                                       other_id = st->id;
+                                       msg_debug_bayes("found statfile mapping: %s (class=%s) â†’ id=%d",
+                                                                       st->stcf->symbol, other_label, other_id);
+                                       break;
                                }
                        }
+
+                       if (other_id == -1) {
+                               msg_debug_bayes("statfile not found for class %s, skipping", other_label);
+                               /* Skip if statfile not found */
+                               cur = g_list_next(cur);
+                               continue;
+                       }
+
+                       if (other_stcf == stcf) {
+                               /* This is the main statfile, use the main runtime */
+                               rt->save_in_mempool(other_label);
+                               msg_debug_bayes("main runtime: statfile %s (class=%s) â†’ id=%d",
+                                                               stcf->symbol, other_label, rt->id);
+                       }
+                       else {
+                               /* Create additional runtime for other statfile */
+                               auto *other_rt = new redis_stat_runtime<float>(ctx, task, object_expanded);
+                               other_rt->id = other_id;
+                               other_rt->stcf = other_stcf;
+                               other_rt->save_in_mempool(other_label);
+                               msg_debug_bayes("additional runtime: statfile %s (class=%s) â†’ id=%d",
+                                                               other_stcf->symbol, other_label, other_id);
+                       }
+
                        cur = g_list_next(cur);
                }
        }
-       else if (!learn) {
-               /*
-                * For binary classification, create the opposite class runtime to avoid
-                * double call for Redis scripts (backward compatibility).
-                */
-               const char *opposite_label = stcf->is_spam ? "H" : "S";
-               auto maybe_opposite_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
-                                                                                                                                                                          object_expanded, opposite_label);
-
-               if (!maybe_opposite_rt) {
-                       auto *opposite_rt = new redis_stat_runtime<float>(ctx, task, object_expanded);
-                       opposite_rt->save_in_mempool(opposite_label);
-                       opposite_rt->need_redis_call = false;
-               }
+       else {
+               rt->save_in_mempool(class_label);
        }
 
-       rt->save_in_mempool(class_label);
-
        return rt;
 }
 
@@ -859,7 +882,6 @@ rspamd_redis_classified(lua_State *L)
 
        if (rt == nullptr) {
                msg_err_task("internal error: cannot find runtime for cookie %s", cookie);
-
                return 0;
        }
 
@@ -874,135 +896,131 @@ rspamd_redis_classified(lua_State *L)
                        return 0;
                }
 
-               /* Check the array length to determine format:
-                * - Length 4: binary format [learned_ham, learned_spam, ham_tokens, spam_tokens]
-                * - Length 2: multi-class format [learned_counts_table, outputs_table]
-                */
+               /* Redis returns [learned_counts_array, token_results_array]
+                * Both ordered the same way as statfiles in classifier */
                size_t result_len = rspamd_lua_table_size(L, 3);
                msg_debug_bayes("Redis result array length: %uz", result_len);
 
-               auto filler_func = [](redis_stat_runtime<float> *rt, lua_State *L, unsigned learned, int tokens_pos) {
-                       rt->learned = learned;
-                       redis_stat_runtime<float>::result_type *res;
+               if (result_len != 2) {
+                       msg_err_task("internal error: expected 2-element result from Redis script, got %uz", result_len);
+                       rt->err = rspamd::util::error("invalid Redis script result format", 500);
+                       return 0;
+               }
 
-                       res = new redis_stat_runtime<float>::result_type();
+               /* Get learned_counts_array and token_results_array */
+               lua_rawgeti(L, 3, 1); /* learned_counts -> position 4 */
+               lua_rawgeti(L, 3, 2); /* token_results -> position 5 */
 
-                       for (lua_pushnil(L); lua_next(L, tokens_pos); lua_pop(L, 1)) {
-                               lua_rawgeti(L, -1, 1);
-                               auto idx = lua_tointeger(L, -1);
-                               lua_pop(L, 1);
+               /* Process results for all statfiles in order using class_index (O(N) instead of O(N²)) */
+               if (rt->stcf->clcf && rt->stcf->clcf->statfiles) {
+                       GList *cur = rt->stcf->clcf->statfiles;
+                       int redis_idx = 1; /* Redis result array index (1-based) */
 
-                               lua_rawgeti(L, -1, 2);
-                               auto value = lua_tonumber(L, -1);
-                               lua_pop(L, 1);
+                       while (cur) {
+                               auto *stcf = (struct rspamd_statfile_config *) cur->data;
 
-                               res->emplace_back(idx, value);
-                       }
+                               /* Direct statfile lookup using global statfiles array */
+                               struct rspamd_stat_ctx *st_ctx = rspamd_stat_get_ctx();
+                               struct rspamd_statfile *st = nullptr;
+
+                               /* Find statfile by config pointer (still O(N) but unavoidable) */
+                               for (unsigned int i = 0; i < st_ctx->statfiles->len; i++) {
+                                       struct rspamd_statfile *candidate = (struct rspamd_statfile *) g_ptr_array_index(st_ctx->statfiles, i);
+                                       if (candidate->stcf == stcf) {
+                                               st = candidate;
+                                               break;
+                                       }
+                               }
+
+                               if (!st) {
+                                       msg_debug_bayes("statfile not found for config %s, skipping", stcf->symbol);
+                                       cur = g_list_next(cur);
+                                       redis_idx++;
+                                       continue;
+                               }
 
-                       rt->set_results(res);
-               };
+                               /* Get or create runtime for this statfile */
+                               auto *statfile_rt = rt; /* Use current runtime for first statfile */
+                               if (stcf != rt->stcf) {
+                                       const char *class_label = get_class_label(stcf);
+                                       auto maybe_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
+                                                                                                                                                                                 rt->redis_object_expanded,
+                                                                                                                                                                                 class_label);
+                                       if (maybe_rt) {
+                                               statfile_rt = maybe_rt.value();
+                                       }
+                                       else {
+                                               msg_debug_bayes("runtime not found for class %s, skipping", class_label);
+                                               cur = g_list_next(cur);
+                                               redis_idx++;
+                                               continue;
+                                       }
+                               }
 
-               bool is_binary_format = (result_len == 4);
+                               /* Ensure correct statfile ID assignment */
+                               statfile_rt->id = st->id;
+
+                               /* Process token results for this statfile (Redis array index redis_idx) */
+                               lua_rawgeti(L, 5, redis_idx); /* Get token_results[redis_idx] */
+                               if (lua_istable(L, -1)) {
+                                       /* Parse token results into statfile runtime */
+                                       auto *res = new std::vector<std::pair<int, float>>();
+
+                                       lua_pushnil(L); /* First key for iteration */
+                                       while (lua_next(L, -2) != 0) {
+                                               if (lua_istable(L, -1) && lua_objlen(L, -1) == 2) {
+                                                       lua_rawgeti(L, -1, 1); /* token_index */
+                                                       lua_rawgeti(L, -2, 2); /* token_count */
+
+                                                       if (lua_isnumber(L, -2) && lua_isnumber(L, -1)) {
+                                                               int token_idx = lua_tointeger(L, -2);
+                                                               float token_count = lua_tonumber(L, -1);
+                                                               res->emplace_back(token_idx, token_count);
+                                                       }
+
+                                                       lua_pop(L, 2); /* Pop token_index and token_count */
+                                               }
+                                               lua_pop(L, 1); /* Pop value, keep key for next iteration */
+                                       }
 
-               if (is_binary_format) {
-                       /* Binary format: [learned_ham, learned_spam, ham_tokens, spam_tokens] */
+                                       statfile_rt->set_results(res);
+                               }
+                               lua_pop(L, 1); /* Pop token_results[redis_idx] */
 
-                       /* Find the opposite runtime for binary classification compatibility */
-                       const char *opposite_label;
-                       if (rt->stcf->class_name) {
-                               /* Multi-class: find a different class (simplified for now) */
-                               opposite_label = strcmp(get_class_label(rt->stcf), "S") == 0 ? "H" : "S";
+                               cur = g_list_next(cur);
+                               redis_idx++;
                        }
-                       else {
-                               /* Binary: use opposite spam/ham */
-                               opposite_label = rt->stcf->is_spam ? "H" : "S";
-                       }
-                       auto opposite_rt_maybe = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
-                                                                                                                                                                                  rt->redis_object_expanded,
-                                                                                                                                                                                  opposite_label);
+               }
 
-                       if (!opposite_rt_maybe) {
-                               msg_err_task("internal error: cannot find opposite runtime for cookie %s", cookie);
-                               return 0;
-                       }
+               /* Clean up stack */
+               lua_pop(L, 2); /* Pop learned_counts and token_results */
 
-                       /* Extract values from the result table at position 3 */
-                       lua_rawgeti(L, 3, 1); /* learned_ham -> position 4 */
-                       lua_rawgeti(L, 3, 2); /* learned_spam -> position 5 */
-                       lua_rawgeti(L, 3, 3); /* ham_tokens -> position 6 */
-                       lua_rawgeti(L, 3, 4); /* spam_tokens -> position 7 */
+               /* Process tokens for all runtimes */
+               g_assert(rt->tokens != nullptr);
 
-                       unsigned learned_ham = lua_tointeger(L, 4);
-                       unsigned learned_spam = lua_tointeger(L, 5);
+               /* Process tokens for all statfiles */
+               if (rt->stcf->clcf && rt->stcf->clcf->statfiles) {
+                       GList *cur = rt->stcf->clcf->statfiles;
 
-                       if (rt->stcf->is_spam || (rt->stcf->class_name && strcmp(get_class_label(rt->stcf), "S") == 0)) {
-                               /* Current runtime is spam, use spam data */
-                               filler_func(rt, L, learned_spam, 7);                       /* spam_tokens at position 7 */
-                               filler_func(opposite_rt_maybe.value(), L, learned_ham, 6); /* ham_tokens at position 6 */
-                       }
-                       else {
-                               /* Current runtime is ham, use ham data */
-                               filler_func(rt, L, learned_ham, 6);                         /* ham_tokens at position 6 */
-                               filler_func(opposite_rt_maybe.value(), L, learned_spam, 7); /* spam_tokens at position 7 */
-                       }
+                       while (cur) {
+                               auto *stcf = (struct rspamd_statfile_config *) cur->data;
+                               const char *class_label = get_class_label(stcf);
 
-                       /* Clean up the stack - pop the 4 extracted values */
-                       lua_pop(L, 4);
+                               auto maybe_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
+                                                                                                                                                                         rt->redis_object_expanded,
+                                                                                                                                                                         class_label);
+                               if (maybe_rt) {
+                                       auto *statfile_rt = maybe_rt.value();
+                                       statfile_rt->process_tokens(rt->tokens);
+                               }
 
-                       /* Process all tokens */
-                       g_assert(rt->tokens != nullptr);
-                       rt->process_tokens(rt->tokens);
-                       opposite_rt_maybe.value()->process_tokens(rt->tokens);
+                               cur = g_list_next(cur);
+                       }
                }
                else {
-                       /* Multi-class format: [learned_counts_table, outputs_table] */
-
-                       /* Get learned counts table (index 1) and outputs table (index 2) */
-                       lua_rawgeti(L, 3, 1); /* learned_counts */
-                       lua_rawgeti(L, 3, 2); /* outputs */
-
-                       /* Iterate through all class labels to fill all runtimes */
-                       if (rt->stcf->clcf && rt->stcf->clcf->class_labels) {
-                               GHashTableIter iter;
-                               gpointer key, value;
-                               g_hash_table_iter_init(&iter, rt->stcf->clcf->class_labels);
-
-                               while (g_hash_table_iter_next(&iter, &key, &value)) {
-                                       const char *class_label = (const char *) value;
-
-                                       /* Find runtime for this class */
-                                       auto class_rt_maybe = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
-                                                                                                                                                                                               rt->redis_object_expanded,
-                                                                                                                                                                                               class_label);
-
-                                       if (class_rt_maybe) {
-                                               auto *class_rt = class_rt_maybe.value();
-
-                                               /* Get learned count for this class */
-                                               lua_pushstring(L, class_label);
-                                               lua_gettable(L, -3); /* learned_counts[class_label] */
-                                               unsigned learned = lua_tointeger(L, -1);
-                                               lua_pop(L, 1);
-
-                                               /* Get outputs for this class */
-                                               lua_pushstring(L, class_label);
-                                               lua_gettable(L, -2); /* outputs[class_label] */
-                                               int outputs_pos = lua_gettop(L);
-
-                                               filler_func(class_rt, L, learned, outputs_pos);
-                                               lua_pop(L, 1);
-                                       }
-                               }
-                       }
-
-                       lua_pop(L, 2); /* Pop learned_counts and outputs tables */
-
-                       /* Process tokens for all runtimes */
-                       g_assert(rt->tokens != nullptr);
+                       /* Fallback: just process the main runtime */
                        rt->process_tokens(rt->tokens);
                }
-
-               /* Tokens processed - no need to set flags in multi-class approach */
        }
        else {
                /* Error message is on index 3 */
index 836c27c88d42307eacebb141bb2d45458ffd8510..4d070ee2025b93d4a1f6cbc0c2c494b5b5963565 100644 (file)
@@ -302,15 +302,16 @@ bayes_classify_token_multiclass(struct rspamd_classifier *ctx,
                val = tok->values[id];
 
                if (val > 0) {
-                       /* Find which class this statfile belongs to */
-                       for (j = 0; j < cl->num_classes; j++) {
-                               if (st->stcf->class_name &&
-                                       strcmp(st->stcf->class_name, cl->class_names[j]) == 0) {
-                                       class_counts[j] += val;
-                                       total_count += val;
-                                       cl->total_hits += val;
-                                       break;
-                               }
+                       /* Direct O(1) class index lookup instead of O(N) string comparison */
+                       if (st->stcf->class_name && st->stcf->class_index < cl->num_classes) {
+                               unsigned int class_idx = st->stcf->class_index;
+                               class_counts[class_idx] += val;
+                               total_count += val;
+                               cl->total_hits += val;
+                       }
+                       else {
+                               msg_debug_bayes("invalid class_index %ud >= %ud for statfile %s",
+                                                               st->stcf->class_index, cl->num_classes, st->stcf->symbol);
                        }
                }
        }
@@ -348,7 +349,7 @@ bayes_classify_token_multiclass(struct rspamd_classifier *ctx,
 
                if (tok->t1 && tok->t2) {
                        msg_debug_bayes("token(%s) %uL <%*s:%*s>: weight: %.3f, total_count: %ud, "
-                                                       "processed for %u classes",
+                                                       "processed for %ud classes",
                                                        token_type, tok->data,
                                                        (int) tok->t1->stemmed.len, tok->t1->stemmed.begin,
                                                        (int) tok->t2->stemmed.len, tok->t2->stemmed.begin,
@@ -385,13 +386,27 @@ bayes_classify_multiclass(struct rspamd_classifier *ctx,
        cl.cfg = ctx->cfg;
 
        /* Get class information from classifier config */
-       if (!ctx->cfg->class_names || ctx->cfg->class_names->len < 2) {
-               msg_debug_bayes("insufficient classes for multiclass classification");
+       if (!ctx->cfg->class_names) {
+               msg_debug_bayes("no class_names array in classifier config");
+               return TRUE; /* Fall back to binary mode */
+       }
+       if (ctx->cfg->class_names->len < 2) {
+               msg_debug_bayes("insufficient classes: %ud < 2", (unsigned int) ctx->cfg->class_names->len);
+               return TRUE; /* Fall back to binary mode */
+       }
+       if (!ctx->cfg->class_names->pdata) {
+               msg_debug_bayes("class_names->pdata is NULL");
                return TRUE; /* Fall back to binary mode */
        }
 
        cl.num_classes = ctx->cfg->class_names->len;
        cl.class_names = (char **) ctx->cfg->class_names->pdata;
+
+       /* Debug: verify class names are accessible */
+       msg_debug_bayes("multiclass setup: ctx->cfg->class_names=%p, len=%ud, pdata=%p",
+                                       ctx->cfg->class_names, (unsigned int) ctx->cfg->class_names->len, ctx->cfg->class_names->pdata);
+       msg_debug_bayes("multiclass setup: cl.num_classes=%ud, cl.class_names=%p",
+                                       cl.num_classes, cl.class_names);
        cl.class_log_probs = g_alloca(cl.num_classes * sizeof(double));
        cl.class_learns = g_alloca(cl.num_classes * sizeof(uint64_t));
 
@@ -459,6 +474,22 @@ bayes_classify_multiclass(struct rspamd_classifier *ctx,
        }
 
        if (cl.processed_tokens == 0) {
+               /* Debug: check why no tokens were processed */
+               msg_debug_bayes("examining token values for debugging:");
+               for (i = 0; i < MIN(tokens->len, 10); i++) { /* Check first 10 tokens */
+                       tok = g_ptr_array_index(tokens, i);
+                       for (j = 0; j < ctx->statfiles_ids->len; j++) {
+                               id = g_array_index(ctx->statfiles_ids, int, j);
+                               if (tok->values[id] > 0) {
+                                       struct rspamd_statfile *st = g_ptr_array_index(ctx->ctx->statfiles, id);
+                                       msg_debug_bayes("token %ud: values[%d] = %.2f (class=%s, symbol=%s)",
+                                                                       i, id, tok->values[id],
+                                                                       st->stcf->class_name ? st->stcf->class_name : "unknown",
+                                                                       st->stcf->symbol);
+                               }
+                       }
+               }
+
                msg_info_bayes("no tokens found in bayes database "
                                           "(%ud total tokens, %ud text tokens), ignore stats",
                                           tokens->len, text_tokens);