]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Fix] Further fixes to rescore tool
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 7 Mar 2018 13:59:27 +0000 (13:59 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 7 Mar 2018 13:59:27 +0000 (13:59 +0000)
lualib/rspamadm/rescore.lua

index 16b7cf8b8d377f5f5830f481a41905a773c686db..ae61a58ba020b41af538901393b9e983d8c81150 100644 (file)
@@ -15,7 +15,7 @@ local ignore_symbols = {
   ['DATE_IN_FUTURE'] = true,
 }
 
-local function make_dataset_from_logs(logs, all_symbols)
+local function make_dataset_from_logs(logs, all_symbols, spam_score)
   -- Returns a list of {input, output} for torch SGD train
 
   local dataset = {}
@@ -125,7 +125,7 @@ local function update_logs(logs, symbol_scores)
 
     for j=4,#log do
       log[j] = log[j]:gsub("%s+", "")
-      score = score + (symbol_scores[log[j     ]] or 0)
+      score = score + (symbol_scores[log[j]] or 0)
     end
 
     log[2] = lua_util.round(score, 2)
@@ -174,7 +174,7 @@ local function print_score_diff(new_symbol_scores, original_symbol_scores)
 
 end
 
-local function calculate_fscore_from_weights(logs, all_symbols, weights, bias, threshold)
+local function calculate_fscore_from_weights(logs, all_symbols, weights, threshold)
 
   local new_symbol_scores = weights:clone()
 
@@ -210,7 +210,7 @@ end
 
 -- training function
 local function train(dataset, opt, model, criterion, epoch,
-                     all_symbols)
+                     all_symbols, spam_threshold)
   -- epoch tracker
   epoch = epoch or 1
 
@@ -284,9 +284,10 @@ local function train(dataset, opt, model, criterion, epoch,
 
       -- update confusion
       for i = 1,(last - t + 1) do
-        local class_predicted = 0
-        if outputs[i][1] > 0.5 then class_predicted = 1 end
-        confusion:add(class_predicted + 1, targets[i] + 1)
+        local class_predicted, target_class = 1, 1
+        if outputs[i][1] > 0.5 then class_predicted = 2 end
+        if targets[i] > 0.5 then target_class = 2 end
+        confusion:add(class_predicted, target_class)
       end
 
       -- return f and df/dX
@@ -395,16 +396,16 @@ local function get_threshold()
   local actions = rspamd_config:get_all_actions()
 
   if opts['spam-action'] then
-    return actions[opts['spam-action']] or 0
-  else
-    return actions['add header'] or actions['rewrite subject'] or actions['reject']
+    return (actions[opts['spam-action']] or 0),actions['reject']
   end
+  return (actions['add header'] or actions['rewrite subject']
+      or actions['reject']), actions['reject']
 end
 
 return function (args, cfg)
   opts = default_opts
   override_defaults(opts, getopt.getopt(args, 'i:'))
-  local threshold = get_threshold()
+  local threshold,reject_score = get_threshold()
   local logs = rescore_utility.get_all_logs(cfg["logdir"])
 
   if opts['ignore-symbol'] then
@@ -466,22 +467,22 @@ return function (args, cfg)
   local train_logs, validation_logs = split_logs(logs, 70)
   local cv_logs, test_logs = split_logs(validation_logs, 50)
 
-  local dataset = make_dataset_from_logs(train_logs, all_symbols)
+  local dataset = make_dataset_from_logs(train_logs, all_symbols, reject_score)
 
 
   -- Start of perceptron training
   local input_size = #all_symbols
   torch.setnumthreads(opts['threads'])
 
-  local linear_module = nn.Linear(input_size, 1)
-  local activation = nn.Tanh()
+  local linear_module = nn.Linear(input_size, 1, false)
+  local activation = nn.Sigmoid()
 
   local perceptron = nn.Sequential()
   perceptron:add(linear_module)
   perceptron:add(activation)
 
   local criterion = nn.MSECriterion()
-  criterion.sizeAverage = false
+  --criterion.sizeAverage = false
 
   local best_fscore = -math.huge
   local best_weights = linear_module.weight[1]:clone()
@@ -494,13 +495,12 @@ return function (args, cfg)
       opts.learning_rate = lr
       opts.weight_decay = wd
       for i=1,tonumber(opts.iters) do
-        train(dataset, opts, perceptron, criterion, i, all_symbols)
+        train(dataset, opts, perceptron, criterion, i, all_symbols, threshold)
       end
 
       local fscore = calculate_fscore_from_weights(cv_logs,
           all_symbols,
           linear_module.weight[1],
-          linear_module.bias[1],
           threshold)
 
       logger.messagex("Cross-validation fscore=%s, learning rate=%s, weight decay=%s",