git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Minor] Fix single class fallback
author Vsevolod Stakhov <vsevolod@rspamd.com>
Tue, 29 Jul 2025 08:16:42 +0000 (09:16 +0100)
committer Vsevolod Stakhov <vsevolod@rspamd.com>
Tue, 29 Jul 2025 08:16:42 +0000 (09:16 +0100)
src/libstat/backends/redis_backend.cxx

index b80a7a2bd81a3025b1fdf2eb445fc4ec8d90a422..f355133d403374600b6381ee87d5c29b1b75b3ae 100644 (file)
@@ -914,120 +914,229 @@ rspamd_redis_classified(lua_State *L)
                lua_rawgeti(L, 3, 1); /* learned_counts -> position 4 */
                lua_rawgeti(L, 3, 2); /* token_results -> position 5 */
 
-               /* First, process learned_counts using class_names order */
-               if (lua_istable(L, 4) && rt->stcf->clcf && rt->stcf->clcf->class_names &&
-                       rt->stcf->clcf->class_names->len > 0) {
-                       /* Process each class in the same order as sent to Redis */
-                       for (unsigned int class_idx = 0; class_idx < rt->stcf->clcf->class_names->len; class_idx++) {
-                               const char *class_name = (const char *) g_ptr_array_index(rt->stcf->clcf->class_names, class_idx);
-
-                               /* Find statfile with this class name */
+               /* First, process learned_counts */
+               if (lua_istable(L, 4) && rt->stcf->clcf) {
+                       if (rt->stcf->clcf->class_names && rt->stcf->clcf->class_names->len > 0) {
+                               /* Multi-class: use class_names order */
+                               for (unsigned int class_idx = 0; class_idx < rt->stcf->clcf->class_names->len; class_idx++) {
+                                       const char *class_name = (const char *) g_ptr_array_index(rt->stcf->clcf->class_names, class_idx);
+
+                                       /* Find statfile with this class name */
+                                       GList *cur = rt->stcf->clcf->statfiles;
+                                       while (cur) {
+                                               auto *stcf = (struct rspamd_statfile_config *) cur->data;
+                                               if (stcf->class_name && strcmp(stcf->class_name, class_name) == 0) {
+                                                       const char *class_label = get_class_label(stcf);
+
+                                                       /* Get the runtime for this statfile */
+                                                       auto maybe_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(rt->task,
+                                                                                                                                                                                                 rt->redis_object_expanded,
+                                                                                                                                                                                                 class_label);
+                                                       if (maybe_rt) {
+                                                               auto *statfile_rt = maybe_rt.value();
+
+                                                               /* Extract learned count using class index (1-based for Lua) */
+                                                               lua_rawgeti(L, 4, class_idx + 1);
+                                                               if (lua_isnumber(L, -1)) {
+                                                                       statfile_rt->learned = lua_tointeger(L, -1);
+                                                                       msg_debug_bayes("set learned count for class %s (label %s): %L",
+                                                                                                       class_name, class_label, statfile_rt->learned);
+                                                               }
+                                                               lua_pop(L, 1); /* Pop learned_counts[class_idx + 1] */
+                                                       }
+                                                       break; /* Found the statfile for this class */
+                                               }
+                                               cur = g_list_next(cur);
+                                       }
+                               }
+                       }
+                       else {
+                               /* Binary classification: process statfiles in order */
                                GList *cur = rt->stcf->clcf->statfiles;
+                               unsigned int statfile_idx = 0;
                                while (cur) {
                                        auto *stcf = (struct rspamd_statfile_config *) cur->data;
-                                       if (stcf->class_name && strcmp(stcf->class_name, class_name) == 0) {
-                                               const char *class_label = get_class_label(stcf);
-
-                                               /* Get the runtime for this statfile */
-                                               auto maybe_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(rt->task,
-                                                                                                                                                                                         rt->redis_object_expanded,
-                                                                                                                                                                                         class_label);
-                                               if (maybe_rt) {
-                                                       auto *statfile_rt = maybe_rt.value();
-
-                                                       /* Extract learned count using class index (1-based for Lua) */
-                                                       lua_rawgeti(L, 4, class_idx + 1);
-                                                       if (lua_isnumber(L, -1)) {
-                                                               statfile_rt->learned = lua_tointeger(L, -1);
-                                                               msg_debug_bayes("set learned count for class %s (label %s): %L",
-                                                                                               class_name, class_label, statfile_rt->learned);
-                                                       }
-                                                       lua_pop(L, 1); /* Pop learned_counts[class_idx + 1] */
+                                       const char *class_label = get_class_label(stcf);
+
+                                       /* Get the runtime for this statfile */
+                                       auto maybe_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(rt->task,
+                                                                                                                                                                                 rt->redis_object_expanded,
+                                                                                                                                                                                 class_label);
+                                       if (maybe_rt) {
+                                               auto *statfile_rt = maybe_rt.value();
+
+                                               /* Extract learned count using statfile index (1-based for Lua) */
+                                               lua_rawgeti(L, 4, statfile_idx + 1);
+                                               if (lua_isnumber(L, -1)) {
+                                                       statfile_rt->learned = lua_tointeger(L, -1);
+                                                       msg_debug_bayes("set learned count for statfile %s (label %s): %L",
+                                                                                       stcf->symbol, class_label, statfile_rt->learned);
                                                }
-                                               break; /* Found the statfile for this class */
+                                               lua_pop(L, 1); /* Pop learned_counts[statfile_idx + 1] */
                                        }
                                        cur = g_list_next(cur);
+                                       statfile_idx++;
                                }
                        }
                }
 
-               /* Process token results using class_names order */
-               if (lua_istable(L, 5) && rt->stcf->clcf && rt->stcf->clcf->class_names &&
-                       rt->stcf->clcf->class_names->len > 0) {
-                       /* Process each class in the same order as sent to Redis */
-                       for (unsigned int class_idx = 0; class_idx < rt->stcf->clcf->class_names->len; class_idx++) {
-                               const char *class_name = (const char *) g_ptr_array_index(rt->stcf->clcf->class_names, class_idx);
+               /* Process token results */
+               if (lua_istable(L, 5) && rt->stcf->clcf) {
+                       if (rt->stcf->clcf->class_names && rt->stcf->clcf->class_names->len > 0) {
+                               /* Multi-class: use class_names order */
+                               for (unsigned int class_idx = 0; class_idx < rt->stcf->clcf->class_names->len; class_idx++) {
+                                       const char *class_name = (const char *) g_ptr_array_index(rt->stcf->clcf->class_names, class_idx);
+
+                                       /* Find statfile with this class name */
+                                       GList *cur = rt->stcf->clcf->statfiles;
+                                       while (cur) {
+                                               auto *stcf = (struct rspamd_statfile_config *) cur->data;
+                                               if (stcf->class_name && strcmp(stcf->class_name, class_name) == 0) {
+                                                       const char *class_label = get_class_label(stcf);
+
+                                                       /* Find the statfile ID */
+                                                       struct rspamd_stat_ctx *st_ctx = rspamd_stat_get_ctx();
+                                                       struct rspamd_statfile *st = nullptr;
+                                                       for (unsigned int i = 0; i < st_ctx->statfiles->len; i++) {
+                                                               struct rspamd_statfile *candidate = (struct rspamd_statfile *) g_ptr_array_index(st_ctx->statfiles, i);
+                                                               if (candidate->stcf == stcf) {
+                                                                       st = candidate;
+                                                                       break;
+                                                               }
+                                                       }
 
-                               /* Find statfile with this class name */
-                               GList *cur = rt->stcf->clcf->statfiles;
-                               while (cur) {
-                                       auto *stcf = (struct rspamd_statfile_config *) cur->data;
-                                       if (stcf->class_name && strcmp(stcf->class_name, class_name) == 0) {
-                                               const char *class_label = get_class_label(stcf);
-
-                                               /* Find the statfile ID */
-                                               struct rspamd_stat_ctx *st_ctx = rspamd_stat_get_ctx();
-                                               struct rspamd_statfile *st = nullptr;
-                                               for (unsigned int i = 0; i < st_ctx->statfiles->len; i++) {
-                                                       struct rspamd_statfile *candidate = (struct rspamd_statfile *) g_ptr_array_index(st_ctx->statfiles, i);
-                                                       if (candidate->stcf == stcf) {
-                                                               st = candidate;
+                                                       if (!st) {
+                                                               msg_debug_bayes("statfile not found for class %s, skipping", class_name);
                                                                break;
                                                        }
-                                               }
-
-                                               if (!st) {
-                                                       msg_debug_bayes("statfile not found for class %s, skipping", class_name);
-                                                       break;
-                                               }
 
-                                               /* Get or create runtime for this statfile */
-                                               auto *statfile_rt = rt; /* Use current runtime if it matches */
-                                               if (stcf != rt->stcf) {
-                                                       auto maybe_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
-                                                                                                                                                                                                 rt->redis_object_expanded,
-                                                                                                                                                                                                 class_label);
-                                                       if (maybe_rt) {
-                                                               statfile_rt = maybe_rt.value();
+                                                       /* Get or create runtime for this statfile */
+                                                       auto *statfile_rt = rt; /* Use current runtime if it matches */
+                                                       if (stcf != rt->stcf) {
+                                                               auto maybe_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
+                                                                                                                                                                                                         rt->redis_object_expanded,
+                                                                                                                                                                                                         class_label);
+                                                               if (maybe_rt) {
+                                                                       statfile_rt = maybe_rt.value();
+                                                               }
+                                                               else {
+                                                                       msg_debug_bayes("runtime not found for class %s, skipping", class_label);
+                                                                       break;
+                                                               }
                                                        }
-                                                       else {
-                                                               msg_debug_bayes("runtime not found for class %s, skipping", class_label);
-                                                               break;
+
+                                                       /* Ensure correct statfile ID assignment */
+                                                       statfile_rt->id = st->id;
+
+                                                       /* Process token results using class index (1-based for Lua) */
+                                                       lua_rawgeti(L, 5, class_idx + 1); /* Get token_results[class_idx + 1] */
+                                                       if (lua_istable(L, -1)) {
+                                                               /* Parse token results into statfile runtime */
+                                                               auto *res = new std::vector<std::pair<int, float>>();
+
+                                                               lua_pushnil(L); /* First key for iteration */
+                                                               while (lua_next(L, -2) != 0) {
+                                                                       if (lua_istable(L, -1) && lua_objlen(L, -1) == 2) {
+                                                                               lua_rawgeti(L, -1, 1); /* token_index */
+                                                                               lua_rawgeti(L, -2, 2); /* token_count */
+
+                                                                               if (lua_isnumber(L, -2) && lua_isnumber(L, -1)) {
+                                                                                       int token_idx = lua_tointeger(L, -2);
+                                                                                       float token_count = lua_tonumber(L, -1);
+                                                                                       res->emplace_back(token_idx, token_count);
+                                                                               }
+
+                                                                               lua_pop(L, 2); /* Pop token_index and token_count */
+                                                                       }
+                                                                       lua_pop(L, 1); /* Pop value, keep key for next iteration */
+                                                               }
+
+                                                               statfile_rt->set_results(res);
                                                        }
+                                                       lua_pop(L, 1); /* Pop token_results[class_idx + 1] */
+                                                       break;         /* Found the statfile for this class */
+                                               }
+                                               cur = g_list_next(cur);
+                                       }
+                               }
+                       }
+                       else {
+                               /* Binary classification: process statfiles in order */
+                               GList *cur = rt->stcf->clcf->statfiles;
+                               unsigned int statfile_idx = 0;
+                               while (cur) {
+                                       auto *stcf = (struct rspamd_statfile_config *) cur->data;
+                                       const char *class_label = get_class_label(stcf);
+
+                                       /* Find the statfile ID */
+                                       struct rspamd_stat_ctx *st_ctx = rspamd_stat_get_ctx();
+                                       struct rspamd_statfile *st = nullptr;
+                                       for (unsigned int i = 0; i < st_ctx->statfiles->len; i++) {
+                                               struct rspamd_statfile *candidate = (struct rspamd_statfile *) g_ptr_array_index(st_ctx->statfiles, i);
+                                               if (candidate->stcf == stcf) {
+                                                       st = candidate;
+                                                       break;
                                                }
+                                       }
 
-                                               /* Ensure correct statfile ID assignment */
-                                               statfile_rt->id = st->id;
-
-                                               /* Process token results using class index (1-based for Lua) */
-                                               lua_rawgeti(L, 5, class_idx + 1); /* Get token_results[class_idx + 1] */
-                                               if (lua_istable(L, -1)) {
-                                                       /* Parse token results into statfile runtime */
-                                                       auto *res = new std::vector<std::pair<int, float>>();
-
-                                                       lua_pushnil(L); /* First key for iteration */
-                                                       while (lua_next(L, -2) != 0) {
-                                                               if (lua_istable(L, -1) && lua_objlen(L, -1) == 2) {
-                                                                       lua_rawgeti(L, -1, 1); /* token_index */
-                                                                       lua_rawgeti(L, -2, 2); /* token_count */
-
-                                                                       if (lua_isnumber(L, -2) && lua_isnumber(L, -1)) {
-                                                                               int token_idx = lua_tointeger(L, -2);
-                                                                               float token_count = lua_tonumber(L, -1);
-                                                                               res->emplace_back(token_idx, token_count);
-                                                                       }
+                                       if (!st) {
+                                               msg_debug_bayes("statfile not found for %s, skipping", stcf->symbol);
+                                               cur = g_list_next(cur);
+                                               statfile_idx++;
+                                               continue;
+                                       }
 
-                                                                       lua_pop(L, 2); /* Pop token_index and token_count */
+                                       /* Get or create runtime for this statfile */
+                                       auto *statfile_rt = rt; /* Use current runtime if it matches */
+                                       if (stcf != rt->stcf) {
+                                               auto maybe_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
+                                                                                                                                                                                         rt->redis_object_expanded,
+                                                                                                                                                                                         class_label);
+                                               if (maybe_rt) {
+                                                       statfile_rt = maybe_rt.value();
+                                               }
+                                               else {
+                                                       msg_debug_bayes("runtime not found for %s, skipping", class_label);
+                                                       cur = g_list_next(cur);
+                                                       statfile_idx++;
+                                                       continue;
+                                               }
+                                       }
+
+                                       /* Ensure correct statfile ID assignment */
+                                       statfile_rt->id = st->id;
+
+                                       /* Process token results using statfile index (1-based for Lua) */
+                                       lua_rawgeti(L, 5, statfile_idx + 1); /* Get token_results[statfile_idx + 1] */
+                                       if (lua_istable(L, -1)) {
+                                               /* Parse token results into statfile runtime */
+                                               auto *res = new std::vector<std::pair<int, float>>();
+
+                                               lua_pushnil(L); /* First key for iteration */
+                                               while (lua_next(L, -2) != 0) {
+                                                       if (lua_istable(L, -1) && lua_objlen(L, -1) == 2) {
+                                                               lua_rawgeti(L, -1, 1); /* token_index */
+                                                               lua_rawgeti(L, -2, 2); /* token_count */
+
+                                                               if (lua_isnumber(L, -2) && lua_isnumber(L, -1)) {
+                                                                       int token_idx = lua_tointeger(L, -2);
+                                                                       float token_count = lua_tonumber(L, -1);
+                                                                       res->emplace_back(token_idx, token_count);
                                                                }
-                                                               lua_pop(L, 1); /* Pop value, keep key for next iteration */
-                                                       }
 
-                                                       statfile_rt->set_results(res);
+                                                               lua_pop(L, 2); /* Pop token_index and token_count */
+                                                       }
+                                                       lua_pop(L, 1); /* Pop value, keep key for next iteration */
                                                }
-                                               lua_pop(L, 1); /* Pop token_results[class_idx + 1] */
-                                               break;         /* Found the statfile for this class */
+
+                                               statfile_rt->set_results(res);
+                                               msg_debug_bayes("set %uz token results for statfile %s (label %s, id=%d)",
+                                                                               res->size(), stcf->symbol, class_label, st->id);
                                        }
+                                       lua_pop(L, 1); /* Pop token_results[statfile_idx + 1] */
+
                                        cur = g_list_next(cur);
+                                       statfile_idx++;
                                }
                        }
                }