]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Minor] Add timings
authorVsevolod Stakhov <vsevolod@rspamd.com>
Tue, 11 Jun 2024 13:32:13 +0000 (14:32 +0100)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Tue, 11 Jun 2024 13:32:13 +0000 (14:32 +0100)
lualib/rspamadm/classifier_test.lua

index fff4be444035c4d3bfd026b7da37fbd8a8f74ece..7bb9a22e654c7de540b5902d3de1061142496cbf 100644 (file)
@@ -111,7 +111,8 @@ local function classify_files(files)
 end
 
 -- Function to evaluate classifier performance
-local function evaluate_results(results, spam_label, ham_label, known_spam_files, known_ham_files, total_cv_files)
+local function evaluate_results(results, spam_label, ham_label,
+                                known_spam_files, known_ham_files, total_cv_files, elapsed)
   local true_positives, false_positives, true_negatives, false_negatives, total = 0, 0, 0, 0, 0
   for _, res in ipairs(results) do
     if res.result == spam_label then
@@ -146,7 +147,8 @@ local function evaluate_results(results, spam_label, ham_label, known_spam_files
   print(string.format("%-20s %-10.2f", "Precision", precision))
   print(string.format("%-20s %-10.2f", "Recall", recall))
   print(string.format("%-20s %-10.2f", "F1 Score", f1_score))
-  print(string.format("%-20s %-10.2f%%", "Classified", total / total_cv_files * 100))
+  print(string.format("%-20s %-10.2f", "Classified (%)", total / total_cv_files * 100))
+  print(string.format("%-20s %-10.2f", "Elapsed time (seconds)", elapsed))
 end
 
 local function handler(args)
@@ -168,11 +170,17 @@ local function handler(args)
       #train_spam, #cv_spam, #train_ham, #cv_ham))
   if not opts.no_learning then
     -- Train classifier
+    local t, train_spam_time, train_ham_time
     print(string.format("Start learn spam, %d messages, %d connections", #train_spam, opts.nconns))
+    t = rspamd_util.get_time()
     train_classifier(train_spam, "learn_spam")
+    train_spam_time = rspamd_util.get_time() - t
     print(string.format("Start learn ham, %d messages, %d connections", #train_ham, opts.nconns))
+    t = rspamd_util.get_time()
     train_classifier(train_ham, "learn_ham")
-    print("Learning done")
+    train_ham_time = rspamd_util.get_time() - t
+    print(string.format("Learning done: %d spam messages in %.2f seconds, %d ham messages in %.2f seconds",
+        #train_spam, train_spam_time, #train_ham, train_ham_time))
   end
 
   -- Classify cross-validation files
@@ -189,10 +197,16 @@ local function handler(args)
 
   print(string.format("Start cross validation, %d messages, %d connections", #cv_files, opts.nconns))
   -- Get classification results
+  local t = rspamd_util.get_time()
   local results = classify_files(cv_files)
+  local elapsed = rspamd_util.get_time() - t
 
   -- Evaluate results
-  evaluate_results(results, "spam", "ham", known_spam_files, known_ham_files, #cv_files)
+  evaluate_results(results, "spam", "ham",
+      known_spam_files,
+      known_ham_files,
+      #cv_files,
+      elapsed)
 
 end