]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Fix] Fix multiclass Bayes learn cache and classifier isolation
authorDmitriy Alekseev <1865999+dragoangel@users.noreply.github.com>
Fri, 20 Feb 2026 16:31:33 +0000 (17:31 +0100)
committerDmitriy Alekseev <1865999+dragoangel@users.noreply.github.com>
Fri, 20 Feb 2026 16:31:33 +0000 (17:31 +0100)
Propagate task->classifier from the HTTP header so class-targeted
learning skips unrelated classifiers early. Validate the requested
class exists in a statfile before tokenisation to fail fast with a
clear 404 error on misconfigured setups.

Replace numeric class_id hash cache keys in learned_ids with direct class
name strings throughout the Redis cache layer (C, Lua, Redis scripts) to
fix uint64_t precision loss through Lua's 53-bit doubles, which caused
equality checks to always fail for arbitrary multiclass class names.

Add RSPAMD_FLAG_CLASSIFIER_MULTICLASS at config load time
to route probability lookups to the correct result type per classifier.

Store multiclass and binary Bayes results under per-classifier mempool
keys (multiclass_result:<name>, bayes_prob:<name>) and update the
set/get API to take a classifier_name parameter, preventing
cross-contamination when multiple classifiers are configured.

Switch multiclass result allocation from heap (g_new0/g_new) to task
memory pool, eliminating the need for rspamd_multiclass_result_free.

Inject can_learn_prob and can_learn_class into the mempool before
invoking Lua learn conditions, replacing the legacy asymmetric
spam_min/ham_max pair with a unified min_prob threshold.

Remove the cl_skipped gate in rspamd_stat_cache_check so the Redis
learned_ids cache is always consulted before learn conditions run for
proper logging and response handling on relearning same emails.

lualib/lua_bayes_learn.lua
lualib/lua_bayes_redis.lua
lualib/redis_scripts/bayes_cache_check.lua
lualib/redis_scripts/bayes_cache_learn.lua
src/controller.c
src/libserver/cfg_file.h
src/libserver/cfg_rcl.cxx
src/libstat/classifiers/bayes.c
src/libstat/learn_cache/redis_cache.cxx
src/libstat/stat_api.h
src/libstat/stat_process.c

index d44b78eee08b6fc50dce72b76dbfc0ceae230055..5e7d0cc449420ef1a5dcfed64d988495db4566ed 100644 (file)
@@ -174,10 +174,14 @@ local default_can_learn_settings = {
   },
   probability_check = {
     enabled = true,
-    variable = 'bayes_prob',
+    -- 'can_learn_prob' is written by the C layer (rspamd_stat_classifier_is_skipped)
+    -- before invoking this condition. It holds the per-classifier, per-class
+    -- probability so each classifier's can_learn decision is independent.
+    variable = 'can_learn_prob',
     ctype = 'double',
-    spam_min = 0.95,
-    ham_max = 0.05,
+    -- Unified threshold: >= min_prob means "already in this class, skip learning".
+    -- Replaces the old asymmetric spam_min/ham_max pair.
+    min_prob = 0.95,
     skip_for_unlearn = false,
     require_value = false,
   },
@@ -250,11 +254,6 @@ exports.unregister_autolearn_guard = function(name)
   unregister_guard(autolearn_guards, name)
 end
 
-local function format_probability_message(ctx, prob, cl)
-  local pct = math.abs((prob - 0.5) * 200.0)
-
-  return string.format('already in class %s; probability %.2f%%', cl, pct)
-end
 
 --- Determines if a message can be learned by Bayes
 -- @param task rspamd_task
@@ -363,22 +362,25 @@ exports.can_learn = function(task, is_spam, is_unlearn, overrides)
         if probability_opts.check and type(probability_opts.check) == 'function' then
           in_class, guard_msg = probability_opts.check(ctx, prob)
         else
-          if is_spam then
-            in_class = prob >= (probability_opts.spam_min or 0.95)
-          else
-            in_class = prob <= (probability_opts.ham_max or 0.05)
-          end
+          -- Unified check: high probability means the message is already confidently
+          -- in the target class (works for both binary spam/ham and multiclass).
+          -- can_learn_prob is set per-classifier by C so there is no cross-contamination.
+          in_class = prob >= (probability_opts.min_prob or
+                              probability_opts.spam_min or 0.95)
         end
 
         if in_class then
-          local cl = is_spam and 'spam' or 'ham'
+          -- class name is written by C before invoking this condition
+          local cl = task:get_mempool():get_variable('can_learn_class') or
+              (is_spam and 'spam' or 'ham')
           local reason
 
           if probability_opts.message_formatter and type(probability_opts.message_formatter) == 'function' then
             reason = probability_opts.message_formatter(ctx, prob, cl) or guard_msg
           end
 
-          reason = reason or guard_msg or format_probability_message(ctx, prob, cl)
+          reason = reason or guard_msg or
+              string.format('already in class %s; probability %.2f%%', cl, prob * 100.0)
 
           ctx.result.guard = 'probability_check'
           ctx.result.reason = reason
index 91466af6a06be85230b79ca78755468295ee5069..150a51fce013857be1376b56da3fc96a126a585b 100644 (file)
@@ -230,7 +230,10 @@ local function gen_cache_check_functor(redis_params, check_script_id, conf)
       if err then
         callback(task, false, err)
       else
-        if type(data) == 'number' then
+        -- The cached value is now a class name string (e.g. "spam", "ham",
+        -- "transactional").  Previously it was a number (numeric class_id hash),
+        -- but that caused precision loss for large uint64 hashes in Lua doubles.
+        if type(data) == 'string' then
           callback(task, true, data)
         else
           callback(task, false, 'not found')
@@ -247,16 +250,16 @@ end
 
 local function gen_cache_learn_functor(redis_params, learn_script_id, conf)
   local packed_conf = ucl.to_format(conf, 'msgpack')
-  return function(task, cache_id, class_name, class_id)
+  return function(task, cache_id, class_name)
     local function learn_redis_cb(err, data)
       lua_util.debugm(N, task, 'learn_cache redis cb: %s, %s', err, data)
     end
 
-    lua_util.debugm(N, task, 'try to learn cache: %s as %s (id=%s)', cache_id, class_name, class_id)
+    lua_util.debugm(N, task, 'try to learn cache: %s as %s', cache_id, class_name)
     lua_redis.exec_redis_script(learn_script_id,
         { task = task, is_write = true, key = cache_id },
         learn_redis_cb,
-        { cache_id, tostring(class_id), packed_conf })
+        { cache_id, class_name, packed_conf })
   end
 end
 
index f1ffc2b84eb26c2323b3a1ce817bd1b4d8e59b9f..912396e4c2cccccb6e73acf5f091d2ce3d23b63f 100644 (file)
@@ -13,7 +13,10 @@ for i = 0, conf.cache_max_keys do
   local have = redis.call('HGET', prefix, cache_id)
 
   if have then
-    return tonumber(have)
+    -- Return the raw string value (class name, e.g. "spam", "ham", "transactional").
+    -- Previously tonumber() was used here, but 64-bit integer class IDs exceed
+    -- Lua's 53-bit double precision, corrupting the value for multiclass classifiers.
+    return have
   end
 end
 
index a7c9ac443cafb343ebd1135d22fbf829e59738a0..d3ec095a0af940d0d6948d5a16308ee2066ef3ca 100644 (file)
@@ -1,15 +1,20 @@
--- Lua script to perform cache checking for bayes classification (multi-class)
+-- Lua script to perform cache learning for bayes classification (multi-class)
 -- This script accepts the following parameters:
 -- key1 - cache id
--- key2 - class_id (numeric hash of class name, computed by C side)
+-- key2 - class name string (e.g. "spam", "ham", "transactional")
 -- key3 - configuration table in message pack
+--
+-- The cache value stored in Redis is the class name string.  A numeric class_id
+-- hash was used previously, but uint64_t values > 2^53 lose precision when
+-- round-tripped through Lua doubles, so the equality check on retrieval was
+-- unreliable for arbitrary multiclass names.
 
 local cache_id = KEYS[1]
-local class_id = KEYS[2]
+local class_name = KEYS[2]
 local conf = cmsgpack.unpack(KEYS[3])
 
--- Use class_id directly as cache value
-local cache_value = tostring(class_id)
+-- Store the class name directly as the cache value
+local cache_value = class_name
 cache_id = string.sub(cache_id, 1, conf.cache_elt_len)
 
 -- Try each prefix that is in Redis (as some other instance might have set it)
index ebd309798ba18a1cde5c0ebfb64d7b77e0d2827b..c349ddc268eb4356d49358e4b21f03f258a51dc8 100644 (file)
@@ -2163,6 +2163,7 @@ rspamd_controller_handle_learn_common(
        cl_header = rspamd_http_message_find_header(msg, "classifier");
        if (cl_header) {
                session->classifier = rspamd_mempool_ftokdup(session->pool, cl_header);
+               task->classifier = session->classifier;
        }
        else {
                session->classifier = NULL;
@@ -2278,6 +2279,7 @@ rspamd_controller_handle_learnclass(
        cl_header = rspamd_http_message_find_header(msg, "classifier");
        if (cl_header) {
                session->classifier = rspamd_mempool_ftokdup(session->pool, cl_header);
+               task->classifier = session->classifier;
        }
        else {
                session->classifier = NULL;
index 5b9022e2dbbaff4c7b682c4534018ebb6fb8c696..a14624ce8a358f6123c21d2820efb4c105937a92 100644 (file)
@@ -165,6 +165,13 @@ struct rspamd_tokenizer_config {
  * No backend required for classifier
  */
 #define RSPAMD_FLAG_CLASSIFIER_NO_BACKEND (1 << 2)
+/*
+ * Set if classifier has at least one class that is neither "spam" nor "ham"
+ * (i.e. a genuinely multiclass classifier, not just a binary spam/ham one).
+ * When set, can_learn uses multiclass_result:<name> for probability checks
+ * instead of the legacy bayes_prob binary variable.
+ */
+#define RSPAMD_FLAG_CLASSIFIER_MULTICLASS (1 << 3)
 
 /**
  * Classifier config definition
index 5045624490dc52ab0109d7608b1576e24d16856a..add33f39bc87bc7cbd7dc78dc92dae49136d8f12 100644 (file)
@@ -1506,6 +1506,7 @@ rspamd_rcl_classifier_handler(rspamd_mempool_t *pool,
 
                if (has_explicit_classes) {
                        ccf->class_names = g_ptr_array_new();
+                       bool has_custom_class = false;
 
                        cur = ccf->statfiles;
                        while (cur) {
@@ -1526,9 +1527,22 @@ rspamd_rcl_classifier_handler(rspamd_mempool_t *pool,
                                                stcf->class_index = ccf->class_names->len;
                                                g_ptr_array_add(ccf->class_names, g_strdup(stcf->class_name));
                                        }
+
+                                       /* Detect genuinely multiclass: any class that is not spam/ham */
+                                       if (strcmp(stcf->class_name, "spam") != 0 &&
+                                               strcmp(stcf->class_name, "S") != 0 &&
+                                               strcmp(stcf->class_name, "ham") != 0 &&
+                                               strcmp(stcf->class_name, "H") != 0) {
+                                               has_custom_class = true;
+                                       }
                                }
                                cur = g_list_next(cur);
                        }
+
+                       if (has_custom_class) {
+                               ccf->flags |= RSPAMD_FLAG_CLASSIFIER_MULTICLASS;
+                               msg_debug("classifier %s flagged as MULTICLASS", ccf->name ? ccf->name : "(unnamed)");
+                       }
                }
        }
 
index afa12e7c621c7b37c4ddf5802beaf4ed9b854869..d008f2735ab36b7796a05bc3db92404b11ce55c0 100644 (file)
@@ -672,20 +672,22 @@ bayes_classify_multiclass(struct rspamd_classifier *ctx,
                confidence = normalized_probs[winning_class_idx];
        }
 
-       /* Create and store multiclass result */
-       result = g_new0(rspamd_multiclass_result_t, 1);
-       result->class_names = g_new(char *, cl.num_classes);
-       result->probabilities = g_new(double, cl.num_classes);
+       /* Create and store multiclass result — all pool-allocated, no destructor needed */
+       result = rspamd_mempool_alloc0(task->task_pool, sizeof(*result));
+       result->class_names = rspamd_mempool_alloc(task->task_pool, cl.num_classes * sizeof(char *));
+       result->probabilities = rspamd_mempool_alloc(task->task_pool, cl.num_classes * sizeof(double));
        result->num_classes = cl.num_classes;
-       result->winning_class = cl.class_names[winning_class_idx]; /* Reference, not copy */
+       result->winning_class = cl.class_names[winning_class_idx]; /* Reference into cfg, valid for task lifetime */
        result->confidence = confidence;
 
        for (i = 0; i < cl.num_classes; i++) {
-               result->class_names[i] = g_strdup(cl.class_names[i]);
+               result->class_names[i] = cl.class_names[i] ?
+                       rspamd_mempool_strdup(task->task_pool, cl.class_names[i]) : NULL;
                result->probabilities[i] = normalized_probs[i];
        }
 
-       rspamd_task_set_multiclass_result(task, result);
+       /* Store via unified API — keyed as "multiclass_result:<name>" (or "" for unnamed) */
+       rspamd_task_set_multiclass_result(task, result, ctx->cfg->name);
 
        msg_info_bayes("MULTICLASS_RESULT: winning_class='%s', confidence=%.3f, normalized_prob=%.3f, tokens=%uL",
                                   cl.class_names[winning_class_idx], confidence,
@@ -922,6 +924,19 @@ bayes_classify(struct rspamd_classifier *ctx,
        pprob = rspamd_mempool_alloc(task->task_pool, sizeof(*pprob));
        *pprob = final_prob;
        rspamd_mempool_set_variable(task->task_pool, "bayes_prob", pprob, NULL);
+       /* Also store per-classifier key so can_learn reads the right classifier's result
+        * when multiple binary classifiers are present. Always written (using "" for
+        * unnamed classifiers) so the lookup in rspamd_stat_classifier_is_skipped
+        * always finds it. */
+       {
+               const char *cl_name = (ctx->cfg->name && *ctx->cfg->name) ? ctx->cfg->name : "";
+               gsize key_len = strlen("bayes_prob:") + strlen(cl_name) + 1;
+               char *per_cl_key = rspamd_mempool_alloc(task->task_pool, key_len);
+               rspamd_snprintf(per_cl_key, key_len, "bayes_prob:%s", cl_name);
+               double *pprob_cl = rspamd_mempool_alloc(task->task_pool, sizeof(*pprob_cl));
+               *pprob_cl = final_prob;
+               rspamd_mempool_set_variable(task->task_pool, per_cl_key, pprob_cl, NULL);
+       }
 
        if (cl.processed_tokens > 0 && fabs(final_prob - 0.5) > 0.05) {
                /* Now we can have exactly one HAM and exactly one SPAM statfiles per classifier */
index 81f31b834b958a1914ec5b6f922bbf84a0e69f19..b4d35cea7dccf1eee3e1ffcf8ccaa527b40a2770 100644 (file)
@@ -152,33 +152,6 @@ rspamd_stat_cache_redis_runtime(struct rspamd_task *task,
        return (void *) ctx;
 }
 
-/* Get class ID using rspamd_cryptobox_fast_hash */
-static uint64_t
-rspamd_stat_cache_get_class_id(const char *class_name)
-{
-       if (!class_name) {
-               return 0;
-       }
-
-       if (strcmp(class_name, "spam") == 0 || strcmp(class_name, "S") == 0) {
-               return 1;
-       }
-       else if (strcmp(class_name, "ham") == 0 || strcmp(class_name, "H") == 0) {
-               return 0;
-       }
-       else {
-               /* For other classes, use rspamd_cryptobox_fast_hash */
-               uint64_t hash = rspamd_cryptobox_fast_hash(class_name, strlen(class_name), 0);
-
-               /* Ensure we don't get 0 or 1 (reserved for ham/spam) */
-               if (hash == 0 || hash == 1) {
-                       hash += 2;
-               }
-
-               return hash;
-       }
-}
-
 static int
 rspamd_stat_cache_checked(lua_State *L)
 {
@@ -191,7 +164,12 @@ rspamd_stat_cache_checked(lua_State *L)
        auto res = lua_toboolean(L, 2);
 
        if (res) {
-               auto val = lua_tointeger(L, 3);
+               /* The cached value is the class name string (e.g. "spam", "ham", "transactional").
+                * Previously this was stored as a 64-bit integer hash, but uint64_t values
+                * larger than 2^53 lose precision when passed through Lua doubles, causing
+                * the equality check to always fail and forcing UNLEARN instead of
+                * ALREADY_LEARNED for multiclass classifiers. */
+               const char *cached_class = lua_tostring(L, 3);
 
                /* Get the class being learned */
                const char *autolearn_class = rspamd_task_get_autolearn_class(task);
@@ -205,11 +183,9 @@ rspamd_stat_cache_checked(lua_State *L)
                        }
                }
 
-               if (autolearn_class) {
-                       uint64_t expected_id = rspamd_stat_cache_get_class_id(autolearn_class);
-
-                       if ((uint64_t) val == expected_id) {
-                               /* Already learned */
+               if (autolearn_class && cached_class) {
+                       if (strcmp(cached_class, autolearn_class) == 0) {
+                               /* Already learned as the same class */
                                msg_info_task("<%s> has been already "
                                                          "learned as %s, ignore it",
                                                          MESSAGE_FIELD(task, message_id),
@@ -217,10 +193,10 @@ rspamd_stat_cache_checked(lua_State *L)
                                task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED;
                        }
                        else {
-                               /* Different class learned, unlearn flag */
-                               msg_debug_task("<%s> cached value %L != expected %uL for class %s, will unlearn",
+                               /* Different class was learned previously, mark for unlearn */
+                               msg_debug_task("<%s> cached class '%s' != requested '%s', will unlearn",
                                                           MESSAGE_FIELD(task, message_id),
-                                                          val, expected_id, autolearn_class);
+                                                          cached_class, autolearn_class);
                                task->flags |= RSPAMD_TASK_FLAG_UNLEARN;
                        }
                }
@@ -291,12 +267,9 @@ int rspamd_stat_cache_redis_learn(struct rspamd_task *task,
                autolearn_class = is_spam ? "spam" : "ham";
        }
 
-       /* Push class name and class ID */
        lua_pushstring(L, autolearn_class);
-       uint64_t class_id = rspamd_stat_cache_get_class_id(autolearn_class);
-       lua_pushinteger(L, class_id);
 
-       if (lua_pcall(L, 4, 0, err_idx) != 0) {
+       if (lua_pcall(L, 3, 0, err_idx) != 0) {
                msg_err_task("call to redis failed: %s", lua_tostring(L, -1));
                lua_settop(L, err_idx - 1);
                return RSPAMD_LEARN_IGNORE;
index aa6111a8b2dfccf612f1883b25205292994a58f1..c71e6fae98a52ed003f54f13ba879c7576666e2c 100644 (file)
@@ -149,20 +149,20 @@ typedef struct {
 } rspamd_multiclass_result_t;
 
 /**
- * Set multi-class classification result for a task
+ * Set multi-class classification result for a task.
+ * @param classifier_name  classifier name, or NULL/"" for unnamed classifiers
+ * Result is stored under the key "multiclass_result:<classifier_name>".
  */
 void rspamd_task_set_multiclass_result(struct rspamd_task *task,
-                                                                          rspamd_multiclass_result_t *result);
+                                                                          rspamd_multiclass_result_t *result,
+                                                                          const char *classifier_name);
 
 /**
- * Get multi-class classification result from a task
+ * Get multi-class classification result from a task.
+ * @param classifier_name  classifier name, or NULL/"" for unnamed classifiers
  */
-rspamd_multiclass_result_t *rspamd_task_get_multiclass_result(struct rspamd_task *task);
-
-/**
- * Free multi-class result structure
- */
-void rspamd_multiclass_result_free(rspamd_multiclass_result_t *result);
+rspamd_multiclass_result_t *rspamd_task_get_multiclass_result(struct rspamd_task *task,
+                                                                                                                          const char *classifier_name);
 
 /**
  * Set autolearn class for a task
index 655e0bdba84290477a8215ddebdfd2af55e31c48..f6db2fa550dee0a0448314c2092739223b02a8a1 100644 (file)
 
 static const double similarity_threshold = 80.0;
 
-void rspamd_task_set_multiclass_result(struct rspamd_task *task, rspamd_multiclass_result_t *result)
+void rspamd_task_set_multiclass_result(struct rspamd_task *task,
+                                                                          rspamd_multiclass_result_t *result,
+                                                                          const char *classifier_name)
 {
        g_assert(task != NULL);
        g_assert(result != NULL);
 
-       rspamd_mempool_set_variable(task->task_pool, "multiclass_bayes_result", result,
-                                                               (rspamd_mempool_destruct_t) rspamd_multiclass_result_free);
+       /* Unified key: "multiclass_result:<name>", empty string for unnamed classifiers */
+       const char *cl_name = (classifier_name && *classifier_name) ? classifier_name : "";
+       gsize key_len = strlen("multiclass_result:") + strlen(cl_name) + 1;
+       char *key = rspamd_mempool_alloc(task->task_pool, key_len);
+       rspamd_snprintf(key, key_len, "multiclass_result:%s", cl_name);
+
+       /* NULL destructor — result is pool-allocated */
+       rspamd_mempool_set_variable(task->task_pool, key, result, NULL);
 }
 
 rspamd_multiclass_result_t *
-rspamd_task_get_multiclass_result(struct rspamd_task *task)
+rspamd_task_get_multiclass_result(struct rspamd_task *task, const char *classifier_name)
 {
        g_assert(task != NULL);
 
-       return (rspamd_multiclass_result_t *) rspamd_mempool_get_variable(task->task_pool,
-                                                                                                                                         "multiclass_bayes_result");
-}
+       const char *cl_name = (classifier_name && *classifier_name) ? classifier_name : "";
+       gsize key_len = strlen("multiclass_result:") + strlen(cl_name) + 1;
+       char *key = rspamd_mempool_alloc(task->task_pool, key_len);
+       rspamd_snprintf(key, key_len, "multiclass_result:%s", cl_name);
 
-void rspamd_multiclass_result_free(rspamd_multiclass_result_t *result)
-{
-       if (result == NULL) {
-               return;
-       }
-
-       g_free(result->class_names);
-       g_free(result->probabilities);
-       /* winning_class is a reference, not owned - don't free */
-       g_free(result);
+       return (rspamd_multiclass_result_t *) rspamd_mempool_get_variable(task->task_pool, key);
 }
 
 void rspamd_task_set_autolearn_class(struct rspamd_task *task, const char *class_name)
@@ -280,12 +280,73 @@ void rspamd_stat_process_tokenize(struct rspamd_stat_ctx *st_ctx,
 
 static gboolean
 rspamd_stat_classifier_is_skipped(struct rspamd_task *task,
-                                                                 struct rspamd_classifier *cl, gboolean is_learn, gboolean is_spam)
+                                                                 struct rspamd_classifier *cl, gboolean is_learn, gboolean is_spam,
+                                                                 const char *learn_class_name)
 {
        GList *cur = is_learn ? cl->cfg->learn_conditions : cl->cfg->classify_conditions;
        lua_State *L = task->cfg->lua_state;
        gboolean ret = FALSE;
 
+       /*
+        * Before calling the Lua learn condition, populate "can_learn_prob" in the
+        * mempool with the probability from THIS specific classifier's result.
+        *
+        * Binary classifiers:   read "bayes_prob:<name>" (set by bayes_classify())
+        * Multiclass classifiers: read "multiclass_result:<name>", find target class
+        *
+        * Per-classifier keys prevent cross-contamination when multiple classifiers
+        * of the same type are configured. Falls back to NULL (= skip probability
+        * check) if this classifier has no result yet (e.g. zero learns).
+        */
+       if (is_learn && learn_class_name != NULL) {
+               double *can_learn_prob_ptr = NULL;
+               /* Use "" for unnamed classifiers — matches what bayes_classify() stores */
+               const char *cl_name = (cl->cfg->name && *cl->cfg->name) ? cl->cfg->name : "";
+
+               if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_MULTICLASS) {
+                       /* Look up THIS classifier's multiclass result via the unified API */
+                       rspamd_multiclass_result_t *mc_result =
+                               rspamd_task_get_multiclass_result(task, cl_name);
+                       if (mc_result != NULL) {
+                               for (unsigned int mci = 0; mci < mc_result->num_classes; mci++) {
+                                       if (mc_result->class_names[mci] != NULL &&
+                                               strcmp(mc_result->class_names[mci], learn_class_name) == 0) {
+                                               can_learn_prob_ptr = rspamd_mempool_alloc(task->task_pool,
+                                                                                                                                 sizeof(double));
+                                               *can_learn_prob_ptr = mc_result->probabilities[mci];
+                                               break;
+                                       }
+                               }
+                       }
+                       /* NULL means classifier has no result (zero learns) → skip prob check → allow */
+               }
+               else {
+                       /* Look up THIS classifier's binary bayes_prob by name.
+                        * bayes_prob is the spam probability (0=ham, 1=spam).
+                        * Convert to "probability of the class being learned" so the
+                        * unified >= threshold check works correctly for both directions:
+                        * learning spam: use raw prob (high = already spam = skip)
+                        * learning ham:  use 1-prob  (high = already ham = skip) */
+                       gsize key_len = strlen("bayes_prob:") + strlen(cl_name) + 1;
+                       char *per_cl_key = rspamd_mempool_alloc(task->task_pool, key_len);
+                       rspamd_snprintf(per_cl_key, key_len, "bayes_prob:%s", cl_name);
+                       double *raw_prob = (double *) rspamd_mempool_get_variable(task->task_pool,
+                                                                                                                                         per_cl_key);
+                       if (raw_prob != NULL) {
+                               can_learn_prob_ptr = rspamd_mempool_alloc(task->task_pool, sizeof(double));
+                               gboolean learning_ham = (strcmp(learn_class_name, "ham") == 0 ||
+                                                                                strcmp(learn_class_name, "H") == 0);
+                               *can_learn_prob_ptr = learning_ham ? (1.0 - *raw_prob) : *raw_prob;
+                       }
+               }
+
+               rspamd_mempool_set_variable(task->task_pool, "can_learn_prob",
+                                                                       can_learn_prob_ptr, NULL);
+               /* Also store the class name so can_learn() can include it in log messages */
+               rspamd_mempool_set_variable(task->task_pool, "can_learn_class",
+                                                                       (gpointer) learn_class_name, NULL);
+       }
+
        while (cur) {
                int cb_ref = GPOINTER_TO_INT(cur->data);
                int old_top = lua_gettop(L);
@@ -371,6 +432,9 @@ rspamd_stat_preprocess(struct rspamd_stat_ctx *st_ctx,
                g_ptr_array_index(task->stat_runtimes, i) = GSIZE_TO_POINTER(G_MAXSIZE);
        }
 
+       /* When learning a specific class, retrieve it once for use in the loop below */
+       const char *learn_class_name = is_learn ? rspamd_task_get_autolearn_class(task) : NULL;
+
        for (i = 0; i < st_ctx->classifiers->len; i++) {
                struct rspamd_classifier *cl = g_ptr_array_index(st_ctx->classifiers, i);
                gboolean skip_classifier = FALSE;
@@ -379,9 +443,38 @@ rspamd_stat_preprocess(struct rspamd_stat_ctx *st_ctx,
                        skip_classifier = TRUE;
                }
                else {
-                       if (rspamd_stat_classifier_is_skipped(task, cl, is_learn, is_spam)) {
+                       /* Respect task->classifier filter: if a specific classifier was
+                        * requested, skip all others without running can_learn on them */
+                       if (is_learn && task->classifier != NULL &&
+                                       (cl->cfg->name == NULL ||
+                                        g_ascii_strcasecmp(task->classifier, cl->cfg->name) != 0)) {
                                skip_classifier = TRUE;
                        }
+
+                       /* For class-based learning: skip classifiers that don't have the
+                        * target class at all — no need to run can_learn on them. */
+                       if (!skip_classifier && is_learn && learn_class_name != NULL) {
+                               gboolean cl_has_class = FALSE;
+                               for (int j = 0; j < cl->statfiles_ids->len; j++) {
+                                       int id = g_array_index(cl->statfiles_ids, int, j);
+                                       struct rspamd_statfile *cst = g_ptr_array_index(st_ctx->statfiles, id);
+                                       if (cst->stcf->class_name &&
+                                                       strcmp(cst->stcf->class_name, learn_class_name) == 0) {
+                                               cl_has_class = TRUE;
+                                               break;
+                                       }
+                               }
+                               if (!cl_has_class) {
+                                       skip_classifier = TRUE;
+                               }
+                       }
+
+                       if (!skip_classifier) {
+                               if (rspamd_stat_classifier_is_skipped(task, cl, is_learn, is_spam,
+                                                                                                         learn_class_name)) {
+                                       skip_classifier = TRUE;
+                               }
+                       }
                }
 
                if (skip_classifier) {
@@ -621,8 +714,6 @@ rspamd_stat_cache_check(struct rspamd_stat_ctx *st_ctx,
        struct rspamd_classifier *cl, *sel = NULL;
        gpointer rt;
        unsigned int i;
-       gboolean any_considered = FALSE;
-       gboolean any_available = FALSE;
 
        /* Check whether we have learned that file */
        for (i = 0; i < st_ctx->classifiers->len; i++) {
@@ -635,29 +726,6 @@ rspamd_stat_cache_check(struct rspamd_stat_ctx *st_ctx,
                }
 
                sel = cl;
-               any_considered = TRUE;
-
-               /* If classifier was skipped by learn conditions in preprocess, skip cache */
-               gboolean cl_skipped = TRUE;
-               if (task->stat_runtimes != NULL) {
-                       for (int j = 0; j < cl->statfiles_ids->len; j++) {
-                               int id = g_array_index(cl->statfiles_ids, int, j);
-                               if (g_ptr_array_index(task->stat_runtimes, id) != NULL) {
-                                       cl_skipped = FALSE;
-                                       break;
-                               }
-                       }
-               }
-               else {
-                       /* No runtimes prepared means not skipped */
-                       cl_skipped = FALSE;
-               }
-
-               if (cl_skipped) {
-                       continue;
-               }
-
-               any_available = TRUE;
 
                if (sel->cache && sel->cachecf) {
                        rt = cl->cache->runtime(task, sel->cachecf, FALSE);
@@ -716,15 +784,6 @@ rspamd_stat_cache_check(struct rspamd_stat_ctx *st_ctx,
                }
        }
 
-       /* If we considered classifiers but all were skipped by conditions, stop early */
-       if (any_considered && !any_available) {
-               g_set_error(err, rspamd_stat_quark(), 204, "all learn conditions "
-                                                                                                  "denied learning %s in %s",
-                                       spam ? "spam" : "ham",
-                                       classifier ? classifier : "default classifier");
-               return FALSE;
-       }
-
        if (sel == NULL) {
                if (classifier) {
                        g_set_error(err, rspamd_stat_quark(), 404, "cannot find classifier "
@@ -1284,6 +1343,40 @@ rspamd_stat_learn_class(struct rspamd_task *task,
        }
 
        if (stage == RSPAMD_TASK_STAGE_LEARN_PRE) {
+               /* Validate that the requested class exists in (at least one statfile of) the
+                * target classifier(s) before doing any further work such as tokenisation,
+                * running learn-conditions, or hitting the cache.  Failing early avoids the
+                * confusing situation where /learnham returns success on a multiclass
+                * classifier that has no "ham" statfile. */
+               gboolean class_valid = FALSE;
+               for (unsigned int ci = 0; ci < st_ctx->classifiers->len; ci++) {
+                       struct rspamd_classifier *cl = g_ptr_array_index(st_ctx->classifiers, ci);
+                       if (classifier != NULL && (cl->cfg->name == NULL ||
+                                       g_ascii_strcasecmp(classifier, cl->cfg->name) != 0)) {
+                               continue;
+                       }
+                       for (unsigned int si = 0; si < cl->statfiles_ids->len; si++) {
+                               int sid = g_array_index(cl->statfiles_ids, int, si);
+                               struct rspamd_statfile *st = g_ptr_array_index(st_ctx->statfiles, sid);
+                               if (st->stcf->class_name &&
+                                               strcmp(st->stcf->class_name, class_name) == 0) {
+                                       class_valid = TRUE;
+                                       break;
+                               }
+                       }
+                       if (class_valid) break;
+               }
+               if (!class_valid) {
+                       if (err && *err == NULL) {
+                               g_set_error(err, rspamd_stat_quark(), 404,
+                                                       "class '%s' is not defined in classifier %s",
+                                                       class_name,
+                                                       classifier ? classifier : "(any)");
+                       }
+                       task->processed_stages |= stage;
+                       return RSPAMD_STAT_PROCESS_ERROR;
+               }
+
                /* Ensure cache comparison uses the exact class we are about to learn */
                rspamd_task_set_autolearn_class(task, class_name);
                /* Process classifiers - determine spam boolean for compatibility */