git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Minor] Add specific calculations for binary classification case
author: Vsevolod Stakhov <vsevolod@rspamd.com>
Tue, 12 Aug 2025 15:38:37 +0000 (16:38 +0100)
committer: Vsevolod Stakhov <vsevolod@rspamd.com>
Tue, 12 Aug 2025 15:39:07 +0000 (16:39 +0100)
src/libstat/classifiers/bayes.c

index dbae98cc28bba4ec9e4786a478505b47fc857cf9..1d5bb2a6fd7a07fe92f48ac89fcc94ec9389faf4 100644 (file)
@@ -331,34 +331,58 @@ bayes_classify_token_multiclass(struct rspamd_classifier *ctx,
 
                w = (fw * total_count) / (1.0 + fw * total_count);
 
-               /* Apply multinomial model for each class */
-               for (j = 0; j < cl->num_classes; j++) {
-                       /* Skip classes with insufficient learns */
-                       if (ctx->cfg->min_learns > 0 && cl->class_learns[j] < ctx->cfg->min_learns) {
-                               continue;
+               if (cl->num_classes == 2) {
+                       /* Binary-compatible path: normalize per-token probabilities across the two classes */
+                       double f0 = (double) class_counts[0] / MAX(1.0, (double) cl->class_learns[0]);
+                       double f1 = (double) class_counts[1] / MAX(1.0, (double) cl->class_learns[1]);
+                       double denom = f0 + f1;
+
+                       if (denom > 0.0) {
+                               double p0 = f0 / denom;
+                               double p1 = f1 / denom;
+                               double bp0 = PROB_COMBINE(p0, total_count, w, 0.5);
+                               double bp1 = PROB_COMBINE(p1, total_count, w, 0.5);
+
+                               /* Bound and apply min strength (relative to 0.5 for binary) */
+                               bp0 = MAX(0.0, MIN(1.0, bp0));
+                               bp1 = MAX(0.0, MIN(1.0, bp1));
+
+                               if (fabs(bp0 - 0.5) >= ctx->cfg->min_prob_strength) {
+                                       cl->class_log_probs[0] += log(bp0);
+                               }
+                               if (fabs(bp1 - 0.5) >= ctx->cfg->min_prob_strength) {
+                                       cl->class_log_probs[1] += log(bp1);
+                               }
                        }
+               }
+               else {
+                       /* General multinomial model for N>2 classes */
+                       for (j = 0; j < cl->num_classes; j++) {
+                               /* Skip classes with insufficient learns */
+                               if (ctx->cfg->min_learns > 0 && cl->class_learns[j] < ctx->cfg->min_learns) {
+                                       continue;
+                               }
 
-                       double class_freq = (double) class_counts[j] / MAX(1.0, (double) cl->class_learns[j]);
-                       double class_prob = PROB_COMBINE(class_freq, total_count, w, 1.0 / cl->num_classes);
+                               double class_freq = (double) class_counts[j] / MAX(1.0, (double) cl->class_learns[j]);
+                               double class_prob = PROB_COMBINE(class_freq, total_count, w, 1.0 / cl->num_classes);
 
-                       /* Ensure probability is properly bounded [0, 1] */
-                       class_prob = MAX(0.0, MIN(1.0, class_prob));
+                               /* Ensure probability is properly bounded [0, 1] */
+                               class_prob = MAX(0.0, MIN(1.0, class_prob));
 
-                       /* Skip probabilities too close to uniform (1/num_classes) */
-                       double uniform_prior = 1.0 / cl->num_classes;
-                       if (fabs(class_prob - uniform_prior) < ctx->cfg->min_prob_strength) {
-                               continue;
-                       }
+                               /* Skip probabilities too close to uniform (1/num_classes) */
+                               double uniform_prior = 1.0 / cl->num_classes;
+                               if (fabs(class_prob - uniform_prior) < ctx->cfg->min_prob_strength) {
+                                       continue;
+                               }
 
-                       cl->class_log_probs[j] += log(class_prob);
+                               cl->class_log_probs[j] += log(class_prob);
+                       }
                }
 
                cl->processed_tokens++;
                if (!(tok->flags & RSPAMD_STAT_TOKEN_FLAG_META)) {
                        cl->text_tokens++;
                }
-
-               /* Per-token debug logging removed to reduce verbosity */
        }
 }