]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Minor] Further fixes to rescore tool
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 30 May 2018 13:54:41 +0000 (14:54 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 30 May 2018 13:54:41 +0000 (14:54 +0100)
lualib/rescore_utility.lua
lualib/rspamadm/rescore.lua

index 39ad63365d4305454805ac11f7cd48e815d41670..195ae4364c9a53498c2798a6ae743240ed94e639 100644 (file)
@@ -11,7 +11,7 @@ function utility.get_all_symbols(logs, ignore_symbols)
 
   for _, line in pairs(logs) do
     line = lua_util.rspamd_str_split(line, " ")
-    for i=4,(#line-2) do
+    for i=4,(#line-1) do
       line[i] = line[i]:gsub("%s+", "")
       if not symbols_set[line[i]] then
         symbols_set[line[i]] = true
@@ -35,16 +35,23 @@ end
 function utility.read_log_file(file)
 
   local lines = {}
+  local messages = {}
 
-  file = assert(io.open(file, "r"))
+  local fd = assert(io.open(file, "r"))
+  local fname = string.gsub(file, "(.*/)(.*)", "%2")
 
-  for line in file:lines() do
-    lines[#lines + 1] = line
-  end
+  for line in fd:lines() do
+    local start,stop = string.find(line, fname .. ':')
+
+    if start and stop then
+      table.insert(lines, string.sub(line, 1, start))
+      table.insert(messages, string.sub(line, stop + 1, -1))
+    end
+end
 
-  io.close(file)
+  io.close(fd)
 
-  return lines
+  return lines,messages
 end
 
 function utility.get_all_logs(dirs)
@@ -55,26 +62,29 @@ function utility.get_all_logs(dirs)
   end
 
   local all_logs = {}
+  local all_messages = {}
 
   for _,dir in ipairs(dirs) do
     if dir:sub(-1, -1) == "/" then
       dir = dir:sub(1, -2)
       local files = rspamd_util.glob(dir .. "/*.log")
       for _, file in pairs(files) do
-        local logs = utility.read_log_file(file)
-        for _, log_line in pairs(logs) do
-          table.insert(all_logs, log_line)
+        local logs,messages = utility.read_log_file(file)
+        for i=1,#logs do
+          table.insert(all_logs, logs[i])
+          table.insert(all_messages, messages[i])
         end
       end
     else
-      local logs = utility.read_log_file(dir)
-      for _, log_line in pairs(logs) do
-        table.insert(all_logs, log_line)
+      local logs,messages = utility.read_log_file(dir)
+      for i=1,#logs do
+        table.insert(all_logs, logs[i])
+        table.insert(all_messages, messages[i])
       end
     end
   end
 
-  return all_logs
+  return all_logs,all_messages
 end
 
 function utility.get_all_symbol_scores(conf, ignore_symbols)
@@ -87,7 +97,7 @@ function utility.get_all_symbol_scores(conf, ignore_symbols)
   end, symbols)))
 end
 
-function utility.generate_statistics_from_logs(logs, threshold)
+function utility.generate_statistics_from_logs(logs, messages, threshold)
 
   -- Returns file_stats table and list of symbol_stats table.
 
@@ -120,9 +130,10 @@ function utility.generate_statistics_from_logs(logs, threshold)
   local no_of_spam = 0
   local no_of_ham = 0
 
-  for _, log in pairs(logs) do
+  for i, log in ipairs(logs) do
     log = lua_util.rspamd_str_trim(log)
     log = lua_util.rspamd_str_split(log, " ")
+    local message = messages[i]
 
     local is_spam = (log[1] == "SPAM")
     local score = tonumber(log[2])
@@ -139,40 +150,38 @@ function utility.generate_statistics_from_logs(logs, threshold)
       true_positives = true_positives + 1
     elseif is_spam and (score < threshold) then
       false_negatives = false_negatives + 1
-      table.insert(all_fns, log[#log])
+      table.insert(all_fns, message)
     elseif not is_spam and (score >= threshold) then
       false_positives = false_positives + 1
-      table.insert(all_fps, log[#log])
+      table.insert(all_fps, message)
     else
       true_negatives = true_negatives + 1
     end
 
-    for i=4, (#log-2) do
-      if all_symbols_stats[log[i]] == nil then
-        all_symbols_stats[log[i]] = {
-          name = log[i],
+    for j=4, (#log-1) do
+      if all_symbols_stats[log[j]] == nil then
+        all_symbols_stats[log[j]] = {
+          name = message,
           no_of_hits = 0,
           spam_hits = 0,
           ham_hits = 0,
           spam_overall = 0
         }
       end
+      local sym = log[j]
 
-      all_symbols_stats[log[i]].no_of_hits =
-      all_symbols_stats[log[i]].no_of_hits + 1
+      all_symbols_stats[sym].no_of_hits = all_symbols_stats[sym].no_of_hits + 1
 
       if is_spam then
-        all_symbols_stats[log[i]].spam_hits =
-        all_symbols_stats[log[i]].spam_hits + 1
+        all_symbols_stats[sym].spam_hits = all_symbols_stats[sym].spam_hits + 1
       else
-        all_symbols_stats[log[i]].ham_hits =
-        all_symbols_stats[log[i]].ham_hits + 1
+        all_symbols_stats[sym].ham_hits = all_symbols_stats[sym].ham_hits + 1
       end
 
       -- Find slowest message
-      if ((tonumber(log[#log-1]) or 0) > file_stats.slowest) then
-          file_stats.slowest = tonumber(log[#log-1])
-          file_stats.slowest_file = log[#log]
+      if ((tonumber(log[#log]) or 0) > file_stats.slowest) then
+          file_stats.slowest = tonumber(log[#log])
+          file_stats.slowest_file = message
       end
     end
   end
index 80b9630f457b0e8344d16a72559e9c2d708de786..cc331c6e85fc338192044dce5ef34bfbb555aa93 100644 (file)
@@ -188,17 +188,18 @@ local function init_weights(all_symbols, original_symbol_scores)
   return weights
 end
 
-local function shuffle(logs)
+local function shuffle(logs, messages)
 
   local size = #logs
   for i = size, 1, -1 do
     local rand = math.random(size)
     logs[i], logs[rand] = logs[rand], logs[i]
+    messages[i], messages[rand] = messages[rand], messages[i]
   end
 
 end
 
-local function split_logs(logs, split_percent)
+local function split_logs(logs, messages, split_percent)
 
   if not split_percent then
     split_percent = 60
@@ -208,16 +209,20 @@ local function split_logs(logs, split_percent)
 
   local test_logs = {}
   local train_logs = {}
+  local test_messages = {}
+  local train_messages = {}
 
   for i=1,split_index do
-    train_logs[#train_logs + 1] = logs[i]
+    table.insert(train_logs, logs[i])
+    table.insert(train_messages, messages[i])
   end
 
   for i=split_index + 1, #logs do
-    test_logs[#test_logs + 1] = logs[i]
+    table.insert(test_logs, logs[i])
+    table.insert(test_messages, messages[i])
   end
 
-  return train_logs, test_logs
+  return {train_logs,train_messages}, {test_logs,test_messages}
 end
 
 local function stitch_new_scores(all_symbols, new_scores)
@@ -291,7 +296,10 @@ local function print_score_diff(new_symbol_scores, original_symbol_scores)
 
 end
 
-local function calculate_fscore_from_weights(logs, all_symbols, weights, threshold)
+local function calculate_fscore_from_weights(logs, messages,
+                                             all_symbols,
+                                             weights,
+                                             threshold)
 
   local new_symbol_scores = weights:clone()
 
@@ -300,14 +308,15 @@ local function calculate_fscore_from_weights(logs, all_symbols, weights, thresho
   logs = update_logs(logs, new_symbol_scores)
 
   local file_stats, _, all_fps, all_fns =
-      rescore_utility.generate_statistics_from_logs(logs, threshold)
+      rescore_utility.generate_statistics_from_logs(logs, messages, threshold)
 
   return file_stats.fscore, all_fps, all_fns
 end
 
-local function print_stats(logs, threshold)
+local function print_stats(logs, messages, threshold)
 
-  local file_stats, _ = rescore_utility.generate_statistics_from_logs(logs, threshold)
+  local file_stats, _ = rescore_utility.generate_statistics_from_logs(logs,
+      messages, threshold)
 
   local file_stat_format = [[
 F-score: %.2f
@@ -519,7 +528,7 @@ local function handler(args)
   end
 
   local threshold,reject_score = get_threshold()
-  local logs = rescore_utility.get_all_logs(opts['log'])
+  local logs,messages = rescore_utility.get_all_logs(opts['log'])
 
   if opts['ignore-symbol'] then
     local function add_ignore(s)
@@ -574,7 +583,9 @@ local function handler(args)
 
   -- Display hit frequencies
   if opts['freq'] then
-      local _, all_symbols_stats = rescore_utility.generate_statistics_from_logs(logs, threshold)
+      local _, all_symbols_stats = rescore_utility.generate_statistics_from_logs(logs,
+          messages,
+          threshold)
       local t = {}
       for _, symbol_stats in pairs(all_symbols_stats) do table.insert(t, symbol_stats) end
 
@@ -607,7 +618,7 @@ local function handler(args)
       end
 
       -- Print file statistics
-      print_stats(logs, threshold)
+      print_stats(logs, messages, threshold)
 
       -- Work out how many symbols weren't seen in the corpus
       local symbols_no_hits = {}
@@ -635,13 +646,13 @@ local function handler(args)
       return
   end
 
-  shuffle(logs)
+  shuffle(logs, messages)
   torch.setdefaulttensortype('torch.FloatTensor')
 
-  local train_logs, validation_logs = split_logs(logs, 70)
-  local cv_logs, test_logs = split_logs(validation_logs, 50)
+  local train_logs, validation_logs = split_logs(logs, messages,70)
+  local cv_logs, test_logs = split_logs(validation_logs[1], validation_logs[2], 50)
 
-  local dataset = make_dataset_from_logs(train_logs, all_symbols, reject_score)
+  local dataset = make_dataset_from_logs(train_logs[1], all_symbols, reject_score)
 
   -- Start of perceptron training
   local input_size = #all_symbols
@@ -675,7 +686,8 @@ local function handler(args)
             initial_weights)
       end
 
-      local fscore, fps, fns = calculate_fscore_from_weights(cv_logs,
+      local fscore, fps, fns = calculate_fscore_from_weights(cv_logs[1],
+          cv_logs[2],
           all_symbols,
           linear_module.weight[1],
           threshold)
@@ -710,13 +722,13 @@ local function handler(args)
 
   -- Pre-rescore test stats
   logger.message("\n\nPre-rescore test stats\n")
-  test_logs = update_logs(test_logs, original_symbol_scores)
-  print_stats(test_logs, threshold)
+  test_logs[1] = update_logs(test_logs[1], original_symbol_scores)
+  print_stats(test_logs[1], test_logs[2], threshold)
 
   -- Post-rescore test stats
-  test_logs = update_logs(test_logs, new_symbol_scores)
+  test_logs[1] = update_logs(test_logs[1], new_symbol_scores)
   logger.message("\n\nPost-rescore test stats\n")
-  print_stats(test_logs, threshold)
+  print_stats(test_logs[1], test_logs[2], threshold)
 
   logger.messagex('Best fscore=%s, best learning rate=%s, best weight decay=%s',
       best_fscore, best_learning_rate, best_weight_decay)