]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Fix] Fix ROC threshold calculation for ham/spam labels
authorVsevolod Stakhov <vsevolod@rspamd.com>
Tue, 20 Jan 2026 16:12:41 +0000 (16:12 +0000)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Tue, 20 Jan 2026 16:12:41 +0000 (16:12 +0000)
The ROC calculation was checking outputs[i][1] == 0 for ham samples,
but the ceb_neg cost function uses -1.0 for ham and 1.0 for spam.
Changed to check outputs[i][1] < 0 to correctly identify ham samples.

lualib/plugins/neural.lua

index 68bdb3c3dc698c02ee9eca5463f24f25a1369e50..000a3fc6c6fb9882ec80251261944743c92c5a97 100644 (file)
@@ -474,7 +474,8 @@ local function get_roc_thresholds(ann, inputs, outputs, alpha, beta)
   spam_count_ahead[n_samples + 1] = 0
 
   for i = n_samples, 1, -1 do
-    if outputs[i][1] == 0 then
+    -- Labels are -1.0 for ham and 1.0 for spam (ceb_neg cost function)
+    if outputs[i][1] < 0 then
       n_ham = n_ham + 1
       ham_count_ahead[i] = 1
       spam_count_ahead[i] = 0
@@ -489,7 +490,8 @@ local function get_roc_thresholds(ann, inputs, outputs, alpha, beta)
   end
 
   for i = 1, n_samples do
-    if outputs[i][1] == 0 then
+    -- Labels are -1.0 for ham and 1.0 for spam (ceb_neg cost function)
+    if outputs[i][1] < 0 then
       ham_count_behind[i] = 1
       spam_count_behind[i] = 0
     else