From: Vsevolod Stakhov Date: Thu, 29 Apr 2021 18:44:40 +0000 (+0100) Subject: [Minor] Neural: Allow to have flat classification if needed X-Git-Tag: 3.0~450 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=6244d64b43baa240d63528849a7a47b3f32eccc3;p=thirdparty%2Frspamd.git [Minor] Neural: Allow to have flat classification if needed --- diff --git a/lualib/plugins/neural.lua b/lualib/plugins/neural.lua index f0d5cf582c..5571335913 100644 --- a/lualib/plugins/neural.lua +++ b/lualib/plugins/neural.lua @@ -57,6 +57,7 @@ local default_options = { -- Check ROC curve and AUC in the ML literature spam_score_threshold = nil, -- neural score threshold for spam (must be 0..1 or nil to disable) ham_score_threshold = nil, -- neural score threshold for ham (must be 0..1 or nil to disable) + flat_threshold_curve = false, -- use binary classification 0/1 when threshold is reached symbol_spam = 'NEURAL_SPAM', symbol_ham = 'NEURAL_HAM', max_inputs = nil, -- when PCA is used diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua index ca11d9e666..2ac8df59f3 100644 --- a/src/plugins/lua/neural.lua +++ b/src/plugins/lua/neural.lua @@ -121,7 +121,11 @@ local function ann_scores_filter(task) local result = score if not rule.spam_score_threshold or result >= rule.spam_score_threshold then - task:insert_result(rule.symbol_spam, result, symscore) + if rule.flat_threshold_curve then + task:insert_result(rule.symbol_spam, 1.0, symscore) + else + task:insert_result(rule.symbol_spam, result, symscore) + end else lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (spam_score_threshold)', rule.prefix, set.name, set.ann.version, symscore, @@ -131,7 +135,11 @@ local function ann_scores_filter(task) local result = -(score) if not rule.ham_score_threshold or result >= rule.ham_score_threshold then - task:insert_result(rule.symbol_ham, result, symscore) + if rule.flat_threshold_curve then + task:insert_result(rule.symbol_ham, 1.0, symscore) + else + task:insert_result(rule.symbol_ham, result, symscore) + end else lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (ham_score_threshold)', rule.prefix, set.name, set.ann.version, result,