git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Project] Fix binary classification and lua scripts
author Vsevolod Stakhov <vsevolod@rspamd.com>
Wed, 23 Jul 2025 13:33:04 +0000 (14:33 +0100)
committer Vsevolod Stakhov <vsevolod@rspamd.com>
Wed, 23 Jul 2025 13:33:04 +0000 (14:33 +0100)
lualib/redis_scripts/bayes_classify.lua
src/libserver/cfg_file.h
src/libserver/cfg_rcl.cxx
src/libserver/cfg_utils.cxx
src/libstat/classifiers/bayes.c

index 923adcc5ad69c83b62f81fc044cfe98cd27bf2e4..d6132e631bed9073d59f235a04d50bc9b6db7724 100644 (file)
@@ -35,7 +35,7 @@ end
 
 -- Get token data for all classes (ordered)
 local token_results = {}
-for i, label in ipairs(class_labels) do
+for i, _ in ipairs(class_labels) do
   token_results[i] = {}
 end
 
@@ -54,7 +54,7 @@ if has_learns then
     local token_data = redis.call('HMGET', token, unpack(class_labels))
 
     if token_data then
-      for j, label in ipairs(class_labels) do
+      for j, _ in ipairs(class_labels) do
         local count = token_data[j]
         if count and tonumber(count) > 0 then
           table.insert(token_results[j], { i, tonumber(count) })
index 5aaaece3552555dd4dcb9933b26c43b6dcd56e06..9f83f80244c44075576bbedc81572b0191c4abe3 100644 (file)
@@ -142,6 +142,7 @@ struct rspamd_statfile_config {
        char *class_name;                      /**< class name for multi-class classification                   */
        unsigned int class_index;              /**< class index for O(1) lookup during classification   */
        gboolean is_spam;                      /**< DEPRECATED: spam flag - use class_name instead              */
+       gboolean is_spam_converted;            /**< TRUE if class_name was converted from is_spam flag  */
        struct rspamd_classifier_config *clcf; /**< parent pointer of classifier configuration                  */
        gpointer data;                         /**< opaque data                                                                                 */
 };
index 5afb467452bccd6675ba46cfaa5ecc900f1df42b..68b6460d891cea12a4593a03360823a852f1166d 100644 (file)
@@ -1215,11 +1215,13 @@ rspamd_rcl_statfile_handler(rspamd_mempool_t *pool, const ucl_object_t *obj,
                                                                                                 strlen(st->symbol), "spam", 4) != -1) {
                                st->is_spam = TRUE;
                                st->class_name = rspamd_mempool_strdup(pool, "spam");
+                               st->is_spam_converted = TRUE;
                        }
                        else if (rspamd_substring_search_caseless(st->symbol,
                                                                                                          strlen(st->symbol), "ham", 3) != -1) {
                                st->is_spam = FALSE;
                                st->class_name = rspamd_mempool_strdup(pool, "ham");
+                               st->is_spam_converted = TRUE;
                        }
                        else {
                                g_set_error(err,
@@ -1242,6 +1244,7 @@ rspamd_rcl_statfile_handler(rspamd_mempool_t *pool, const ucl_object_t *obj,
                        else {
                                st->class_name = rspamd_mempool_strdup(pool, "ham");
                        }
+                       st->is_spam_converted = TRUE;
                }
                /* If class field is present, it was already parsed by the default parser */
                return TRUE;
@@ -1439,31 +1442,60 @@ rspamd_rcl_classifier_handler(rspamd_mempool_t *pool,
 
        cfg->classifiers = g_list_prepend(cfg->classifiers, ccf);
 
-       /* Populate class_names array from statfiles */
+       /* Populate class_names array from statfiles - only for explicit multiclass configs */
        if (ccf->statfiles) {
                GList *cur = ccf->statfiles;
-               ccf->class_names = g_ptr_array_new();
+               gboolean has_explicit_classes = FALSE;
 
+               /* Check if any statfile uses explicit class declaration (not converted from is_spam) */
+               cur = ccf->statfiles;
                while (cur) {
                        struct rspamd_statfile_config *stcf = (struct rspamd_statfile_config *) cur->data;
-                       if (stcf->class_name) {
-                               /* Check if class already exists */
-                               bool found = false;
-                               for (unsigned int i = 0; i < ccf->class_names->len; i++) {
-                                       if (strcmp((char *) g_ptr_array_index(ccf->class_names, i), stcf->class_name) == 0) {
-                                               stcf->class_index = i; /* Store the index for O(1) lookup */
-                                               found = true;
-                                               break;
+                       msg_debug("checking statfile %s: class_name=%s, is_spam_converted=%s",
+                                         stcf->symbol, stcf->class_name ? stcf->class_name : "NULL",
+                                         stcf->is_spam_converted ? "true" : "false");
+                       if (stcf->class_name && !stcf->is_spam_converted) {
+                               has_explicit_classes = TRUE;
+                               break;
+                       }
+                       cur = g_list_next(cur);
+               }
+
+               msg_debug("has_explicit_classes = %s", has_explicit_classes ? "true" : "false");
+
+               /* Only populate class_names for explicit multiclass configurations */
+               if (has_explicit_classes) {
+                       msg_debug("populating class_names for multiclass configuration");
+               }
+               else {
+                       msg_debug("skipping class_names population for binary configuration");
+               }
+
+               if (has_explicit_classes) {
+                       ccf->class_names = g_ptr_array_new();
+
+                       cur = ccf->statfiles;
+                       while (cur) {
+                               struct rspamd_statfile_config *stcf = (struct rspamd_statfile_config *) cur->data;
+                               if (stcf->class_name) {
+                                       /* Check if class already exists */
+                                       bool found = false;
+                                       for (unsigned int i = 0; i < ccf->class_names->len; i++) {
+                                               if (strcmp((char *) g_ptr_array_index(ccf->class_names, i), stcf->class_name) == 0) {
+                                                       stcf->class_index = i; /* Store the index for O(1) lookup */
+                                                       found = true;
+                                                       break;
+                                               }
                                        }
-                               }
 
-                               if (!found) {
-                                       /* Add new class */
-                                       stcf->class_index = ccf->class_names->len;
-                                       g_ptr_array_add(ccf->class_names, g_strdup(stcf->class_name));
+                                       if (!found) {
+                                               /* Add new class */
+                                               stcf->class_index = ccf->class_names->len;
+                                               g_ptr_array_add(ccf->class_names, g_strdup(stcf->class_name));
+                                       }
                                }
+                               cur = g_list_next(cur);
                        }
-                       cur = g_list_next(cur);
                }
        }
 
index c8c08343970a689efae191db3d3776ccded50fae..2533bd65e8a10b6c280abf881e501ce9c69b8436 100644 (file)
@@ -3181,18 +3181,41 @@ rspamd_config_validate_class_config(struct rspamd_classifier_config *ccf, GError
                                 class_count);
        }
 
-       /* Initialize classifier class tracking */
-       if (ccf->class_names) {
-               g_ptr_array_unref(ccf->class_names);
-       }
-       ccf->class_names = g_ptr_array_new_with_free_func(g_free);
-
-       /* Populate class names array */
-       GHashTableIter iter;
-       gpointer key, value;
-       g_hash_table_iter_init(&iter, seen_classes);
-       while (g_hash_table_iter_next(&iter, &key, &value)) {
-               g_ptr_array_add(ccf->class_names, g_strdup((const char *) key));
+       /* Initialize classifier class tracking - only for explicit multiclass configurations */
+       gboolean has_explicit_classes = FALSE;
+
+       /* Check if any statfile uses explicit class declaration (not converted from is_spam) */
+       cur = ccf->statfiles;
+       while (cur) {
+               stcf = (struct rspamd_statfile_config *) cur->data;
+               if (stcf->class_name && !stcf->is_spam_converted) {
+                       has_explicit_classes = TRUE;
+                       break;
+               }
+               cur = g_list_next(cur);
+       }
+
+       /* Only populate class_names for explicit multiclass configurations */
+       if (has_explicit_classes) {
+               if (ccf->class_names) {
+                       g_ptr_array_unref(ccf->class_names);
+               }
+               ccf->class_names = g_ptr_array_new_with_free_func(g_free);
+
+               /* Populate class names array */
+               GHashTableIter iter;
+               gpointer key, value;
+               g_hash_table_iter_init(&iter, seen_classes);
+               while (g_hash_table_iter_next(&iter, &key, &value)) {
+                       g_ptr_array_add(ccf->class_names, g_strdup((const char *) key));
+               }
+       }
+       else {
+               /* Binary configuration - ensure class_names is NULL */
+               if (ccf->class_names) {
+                       g_ptr_array_unref(ccf->class_names);
+                       ccf->class_names = nullptr;
+               }
        }
 
        g_hash_table_destroy(seen_classes);
index 4d070ee2025b93d4a1f6cbc0c2c494b5b5963565..3fd7190aeb82de2ed0f1ffa8135912fe12ddb5fe 100644 (file)
@@ -620,18 +620,27 @@ bayes_classify(struct rspamd_classifier *ctx,
        g_assert(tokens != NULL);
 
        /* Check if this is a multi-class classifier */
+       msg_debug_bayes("classification check: class_names=%p, len=%uz",
+                                       ctx->cfg->class_names,
+                                       ctx->cfg->class_names ? ctx->cfg->class_names->len : 0);
+
        if (ctx->cfg->class_names && ctx->cfg->class_names->len >= 2) {
                /* Verify that at least one statfile has class_name set (indicating new multi-class config) */
                gboolean has_class_names = FALSE;
                for (i = 0; i < ctx->statfiles_ids->len; i++) {
                        int id = g_array_index(ctx->statfiles_ids, int, i);
                        struct rspamd_statfile *st = g_ptr_array_index(ctx->ctx->statfiles, id);
+                       msg_debug_bayes("checking statfile %s: class_name=%s, is_spam_converted=%s",
+                                                       st->stcf->symbol,
+                                                       st->stcf->class_name ? st->stcf->class_name : "NULL",
+                                                       st->stcf->is_spam_converted ? "true" : "false");
                        if (st->stcf->class_name) {
                                has_class_names = TRUE;
-                               break;
                        }
                }
 
+               msg_debug_bayes("has_class_names=%s", has_class_names ? "true" : "false");
+
                if (has_class_names) {
                        msg_debug_bayes("using multiclass classification with %ud classes",
                                                        (unsigned int) ctx->cfg->class_names->len);