From: Vsevolod Stakhov Date: Tue, 29 Jul 2025 08:16:42 +0000 (+0100) Subject: [Minor] Fix single class fallback X-Git-Tag: 3.13.0~38^2~4 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=e3e85f617fec88bdba1bc8706426a8a5ecd8e8c4;p=thirdparty%2Frspamd.git [Minor] Fix single class fallback --- diff --git a/src/libstat/backends/redis_backend.cxx b/src/libstat/backends/redis_backend.cxx index b80a7a2bd8..f355133d40 100644 --- a/src/libstat/backends/redis_backend.cxx +++ b/src/libstat/backends/redis_backend.cxx @@ -914,120 +914,229 @@ rspamd_redis_classified(lua_State *L) lua_rawgeti(L, 3, 1); /* learned_counts -> position 4 */ lua_rawgeti(L, 3, 2); /* token_results -> position 5 */ - /* First, process learned_counts using class_names order */ - if (lua_istable(L, 4) && rt->stcf->clcf && rt->stcf->clcf->class_names && - rt->stcf->clcf->class_names->len > 0) { - /* Process each class in the same order as sent to Redis */ - for (unsigned int class_idx = 0; class_idx < rt->stcf->clcf->class_names->len; class_idx++) { - const char *class_name = (const char *) g_ptr_array_index(rt->stcf->clcf->class_names, class_idx); - - /* Find statfile with this class name */ + /* First, process learned_counts */ + if (lua_istable(L, 4) && rt->stcf->clcf) { + if (rt->stcf->clcf->class_names && rt->stcf->clcf->class_names->len > 0) { + /* Multi-class: use class_names order */ + for (unsigned int class_idx = 0; class_idx < rt->stcf->clcf->class_names->len; class_idx++) { + const char *class_name = (const char *) g_ptr_array_index(rt->stcf->clcf->class_names, class_idx); + + /* Find statfile with this class name */ + GList *cur = rt->stcf->clcf->statfiles; + while (cur) { + auto *stcf = (struct rspamd_statfile_config *) cur->data; + if (stcf->class_name && strcmp(stcf->class_name, class_name) == 0) { + const char *class_label = get_class_label(stcf); + + /* Get the runtime for this statfile */ + auto maybe_rt = redis_stat_runtime::maybe_recover_from_mempool(rt->task, + rt->redis_object_expanded, + class_label); + if (maybe_rt) { + auto *statfile_rt = maybe_rt.value(); + + /* Extract learned count using class index (1-based for Lua) */ + lua_rawgeti(L, 4, class_idx + 1); + if (lua_isnumber(L, -1)) { + statfile_rt->learned = lua_tointeger(L, -1); + msg_debug_bayes("set learned count for class %s (label %s): %L", + class_name, class_label, statfile_rt->learned); + } + lua_pop(L, 1); /* Pop learned_counts[class_idx + 1] */ + } + break; /* Found the statfile for this class */ + } + cur = g_list_next(cur); + } + } + } + else { + /* Binary classification: process statfiles in order */ GList *cur = rt->stcf->clcf->statfiles; + unsigned int statfile_idx = 0; while (cur) { auto *stcf = (struct rspamd_statfile_config *) cur->data; - if (stcf->class_name && strcmp(stcf->class_name, class_name) == 0) { - const char *class_label = get_class_label(stcf); - - /* Get the runtime for this statfile */ - auto maybe_rt = redis_stat_runtime::maybe_recover_from_mempool(rt->task, - rt->redis_object_expanded, - class_label); - if (maybe_rt) { - auto *statfile_rt = maybe_rt.value(); - - /* Extract learned count using class index (1-based for Lua) */ - lua_rawgeti(L, 4, class_idx + 1); - if (lua_isnumber(L, -1)) { - statfile_rt->learned = lua_tointeger(L, -1); - msg_debug_bayes("set learned count for class %s (label %s): %L", - class_name, class_label, statfile_rt->learned); - } - lua_pop(L, 1); /* Pop learned_counts[class_idx + 1] */ + const char *class_label = get_class_label(stcf); + + /* Get the runtime for this statfile */ + auto maybe_rt = redis_stat_runtime::maybe_recover_from_mempool(rt->task, + rt->redis_object_expanded, + class_label); + if (maybe_rt) { + auto *statfile_rt = maybe_rt.value(); + + /* Extract learned count using statfile index (1-based for Lua) */ + lua_rawgeti(L, 4, statfile_idx + 1); + if (lua_isnumber(L, -1)) { + statfile_rt->learned = lua_tointeger(L, -1); + msg_debug_bayes("set learned count for statfile %s (label %s): %L", + stcf->symbol, class_label, statfile_rt->learned); } - break; /* Found the statfile for this class */ + lua_pop(L, 1); /* Pop learned_counts[statfile_idx + 1] */ } cur = g_list_next(cur); + statfile_idx++; } } } - /* Process token results using class_names order */ - if (lua_istable(L, 5) && rt->stcf->clcf && rt->stcf->clcf->class_names && - rt->stcf->clcf->class_names->len > 0) { - /* Process each class in the same order as sent to Redis */ - for (unsigned int class_idx = 0; class_idx < rt->stcf->clcf->class_names->len; class_idx++) { - const char *class_name = (const char *) g_ptr_array_index(rt->stcf->clcf->class_names, class_idx); + /* Process token results */ + if (lua_istable(L, 5) && rt->stcf->clcf) { + if (rt->stcf->clcf->class_names && rt->stcf->clcf->class_names->len > 0) { + /* Multi-class: use class_names order */ + for (unsigned int class_idx = 0; class_idx < rt->stcf->clcf->class_names->len; class_idx++) { + const char *class_name = (const char *) g_ptr_array_index(rt->stcf->clcf->class_names, class_idx); + + /* Find statfile with this class name */ + GList *cur = rt->stcf->clcf->statfiles; + while (cur) { + auto *stcf = (struct rspamd_statfile_config *) cur->data; + if (stcf->class_name && strcmp(stcf->class_name, class_name) == 0) { + const char *class_label = get_class_label(stcf); + + /* Find the statfile ID */ + struct rspamd_stat_ctx *st_ctx = rspamd_stat_get_ctx(); + struct rspamd_statfile *st = nullptr; + for (unsigned int i = 0; i < st_ctx->statfiles->len; i++) { + struct rspamd_statfile *candidate = (struct rspamd_statfile *) g_ptr_array_index(st_ctx->statfiles, i); + if (candidate->stcf == stcf) { + st = candidate; + break; + } + } - /* Find statfile with this class name */ - GList *cur = rt->stcf->clcf->statfiles; - while (cur) { - auto *stcf = (struct rspamd_statfile_config *) cur->data; - if (stcf->class_name && strcmp(stcf->class_name, class_name) == 0) { - const char *class_label = get_class_label(stcf); - - /* Find the statfile ID */ - struct rspamd_stat_ctx *st_ctx = rspamd_stat_get_ctx(); - struct rspamd_statfile *st = nullptr; - for (unsigned int i = 0; i < st_ctx->statfiles->len; i++) { - struct rspamd_statfile *candidate = (struct rspamd_statfile *) g_ptr_array_index(st_ctx->statfiles, i); - if (candidate->stcf == stcf) { - st = candidate; + if (!st) { + msg_debug_bayes("statfile not found for class %s, skipping", class_name); break; } - } - - if (!st) { - msg_debug_bayes("statfile not found for class %s, skipping", class_name); - break; - } - /* Get or create runtime for this statfile */ - auto *statfile_rt = rt; /* Use current runtime if it matches */ - if (stcf != rt->stcf) { - auto maybe_rt = redis_stat_runtime::maybe_recover_from_mempool(task, - rt->redis_object_expanded, - class_label); - if (maybe_rt) { - statfile_rt = maybe_rt.value(); + /* Get or create runtime for this statfile */ + auto *statfile_rt = rt; /* Use current runtime if it matches */ + if (stcf != rt->stcf) { + auto maybe_rt = redis_stat_runtime::maybe_recover_from_mempool(task, + rt->redis_object_expanded, + class_label); + if (maybe_rt) { + statfile_rt = maybe_rt.value(); + } + else { + msg_debug_bayes("runtime not found for class %s, skipping", class_label); + break; + } } - else { - msg_debug_bayes("runtime not found for class %s, skipping", class_label); - break; + + /* Ensure correct statfile ID assignment */ + statfile_rt->id = st->id; + + /* Process token results using class index (1-based for Lua) */ + lua_rawgeti(L, 5, class_idx + 1); /* Get token_results[class_idx + 1] */ + if (lua_istable(L, -1)) { + /* Parse token results into statfile runtime */ + auto *res = new std::vector>(); + + lua_pushnil(L); /* First key for iteration */ + while (lua_next(L, -2) != 0) { + if (lua_istable(L, -1) && lua_objlen(L, -1) == 2) { + lua_rawgeti(L, -1, 1); /* token_index */ + lua_rawgeti(L, -2, 2); /* token_count */ + + if (lua_isnumber(L, -2) && lua_isnumber(L, -1)) { + int token_idx = lua_tointeger(L, -2); + float token_count = lua_tonumber(L, -1); + res->emplace_back(token_idx, token_count); + } + + lua_pop(L, 2); /* Pop token_index and token_count */ + } + lua_pop(L, 1); /* Pop value, keep key for next iteration */ + } + + statfile_rt->set_results(res); } + lua_pop(L, 1); /* Pop token_results[class_idx + 1] */ + break; /* Found the statfile for this class */ + } + cur = g_list_next(cur); + } + } + } + else { + /* Binary classification: process statfiles in order */ + GList *cur = rt->stcf->clcf->statfiles; + unsigned int statfile_idx = 0; + while (cur) { + auto *stcf = (struct rspamd_statfile_config *) cur->data; + const char *class_label = get_class_label(stcf); + + /* Find the statfile ID */ + struct rspamd_stat_ctx *st_ctx = rspamd_stat_get_ctx(); + struct rspamd_statfile *st = nullptr; + for (unsigned int i = 0; i < st_ctx->statfiles->len; i++) { + struct rspamd_statfile *candidate = (struct rspamd_statfile *) g_ptr_array_index(st_ctx->statfiles, i); + if (candidate->stcf == stcf) { + st = candidate; + break; } + } - /* Ensure correct statfile ID assignment */ - statfile_rt->id = st->id; - - /* Process token results using class index (1-based for Lua) */ - lua_rawgeti(L, 5, class_idx + 1); /* Get token_results[class_idx + 1] */ - if (lua_istable(L, -1)) { - /* Parse token results into statfile runtime */ - auto *res = new std::vector>(); - - lua_pushnil(L); /* First key for iteration */ - while (lua_next(L, -2) != 0) { - if (lua_istable(L, -1) && lua_objlen(L, -1) == 2) { - lua_rawgeti(L, -1, 1); /* token_index */ - lua_rawgeti(L, -2, 2); /* token_count */ - - if (lua_isnumber(L, -2) && lua_isnumber(L, -1)) { - int token_idx = lua_tointeger(L, -2); - float token_count = lua_tonumber(L, -1); - res->emplace_back(token_idx, token_count); - } + if (!st) { + msg_debug_bayes("statfile not found for %s, skipping", stcf->symbol); + cur = g_list_next(cur); + statfile_idx++; + continue; + } - lua_pop(L, 2); /* Pop token_index and token_count */ + /* Get or create runtime for this statfile */ + auto *statfile_rt = rt; /* Use current runtime if it matches */ + if (stcf != rt->stcf) { + auto maybe_rt = redis_stat_runtime::maybe_recover_from_mempool(task, + rt->redis_object_expanded, + class_label); + if (maybe_rt) { + statfile_rt = maybe_rt.value(); + } + else { + msg_debug_bayes("runtime not found for %s, skipping", class_label); + cur = g_list_next(cur); + statfile_idx++; + continue; + } + } + + /* Ensure correct statfile ID assignment */ + statfile_rt->id = st->id; + + /* Process token results using statfile index (1-based for Lua) */ + lua_rawgeti(L, 5, statfile_idx + 1); /* Get token_results[statfile_idx + 1] */ + if (lua_istable(L, -1)) { + /* Parse token results into statfile runtime */ + auto *res = new std::vector>(); + + lua_pushnil(L); /* First key for iteration */ + while (lua_next(L, -2) != 0) { + if (lua_istable(L, -1) && lua_objlen(L, -1) == 2) { + lua_rawgeti(L, -1, 1); /* token_index */ + lua_rawgeti(L, -2, 2); /* token_count */ + + if (lua_isnumber(L, -2) && lua_isnumber(L, -1)) { + int token_idx = lua_tointeger(L, -2); + float token_count = lua_tonumber(L, -1); + res->emplace_back(token_idx, token_count); } - lua_pop(L, 1); /* Pop value, keep key for next iteration */ - } - statfile_rt->set_results(res); + lua_pop(L, 2); /* Pop token_index and token_count */ + } + lua_pop(L, 1); /* Pop value, keep key for next iteration */ } - lua_pop(L, 1); /* Pop token_results[class_idx + 1] */ - break; /* Found the statfile for this class */ + + statfile_rt->set_results(res); + msg_debug_bayes("set %uz token results for statfile %s (label %s, id=%d)", + res->size(), stcf->symbol, class_label, st->id); } + lua_pop(L, 1); /* Pop token_results[statfile_idx + 1] */ + cur = g_list_next(cur); + statfile_idx++; } } }