]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Feature] Add pending training keys and fix neural network training issues
authorVsevolod Stakhov <vsevolod@rspamd.com>
Mon, 19 Jan 2026 18:08:58 +0000 (18:08 +0000)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Mon, 19 Jan 2026 18:08:58 +0000 (18:08 +0000)
- Add pending_train_key() for version-independent training vector storage
- Fix variable shadowing bug where ann_trained callback was overwritten
- Add concurrent training prevention via learning_spawned check
- Replace assert with proper error handling for msgpack parsing
- Clean up pending keys after successful training
- Update controller endpoint to use pending keys for manual training
- Fix ev_base:sleep() to register with session events properly
- Update classifier_test.lua to support llm_embeddings classifier testing

Co-Authored-By: Claude <noreply@anthropic.com>
lualib/plugins/neural.lua
lualib/rspamadm/classifier_test.lua
rules/controller/neural.lua
src/lua/lua_util.c
src/plugins/lua/neural.lua

index 919df7f32ca8a82645018233626207fcac591e05..7f16618f19efaa0bb21b9d177eb748d9a481e819 100644 (file)
@@ -587,6 +587,13 @@ local function redis_ann_prefix(rule, settings_name)
     settings.prefix, plugin_ver, rule.prefix, n, settings_name)
 end
 
+-- Returns a stable key for pending training vectors (version-independent)
+-- Used for batch/manual training to avoid version mismatch issues
+local function pending_train_key(rule, set)
+  return string.format('%s_%s_%s_pending',
+    settings.prefix, rule.prefix, set.name)
+end
+
 -- Compute a stable digest for providers configuration
 local function providers_config_digest(providers_cfg, rule)
   if not providers_cfg then return nil end
@@ -816,6 +823,13 @@ end
 
 -- This function receives training vectors, checks them, spawn learning and saves ANN in Redis
 local function spawn_train(params)
+  -- Prevent concurrent training
+  if params.set.learning_spawned then
+    lua_util.debugm(N, rspamd_config, 'spawn_train: training already in progress for %s:%s, skipping',
+      params.rule.prefix, params.set.name)
+    return
+  end
+
   -- Check training data sanity
   -- Now we need to join inputs and create the appropriate test vectors
   local n
@@ -1006,6 +1020,28 @@ local function spawn_train(params)
       else
         rspamd_logger.infox(rspamd_config, 'saved ANN %s:%s to redis: %s',
           params.rule.prefix, params.set.name, params.set.ann.redis_key)
+
+        -- Clean up pending training keys if they were used
+        if params.pending_key then
+          local function cleanup_cb(cleanup_err)
+            if cleanup_err then
+              lua_util.debugm(N, rspamd_config, 'failed to cleanup pending keys: %s', cleanup_err)
+            else
+              lua_util.debugm(N, rspamd_config, 'cleaned up pending training keys for %s',
+                params.pending_key)
+            end
+          end
+          -- Delete both spam and ham pending sets
+          lua_redis.redis_make_request_taskless(params.ev_base,
+            rspamd_config,
+            params.rule.redis,
+            nil,
+            true,       -- is write
+            cleanup_cb,
+            'DEL',
+            { params.pending_key .. '_spam_set', params.pending_key .. '_ham_set' }
+          )
+        end
       end
     end
 
@@ -1026,7 +1062,20 @@ local function spawn_train(params)
       else
         local parser = ucl.parser()
         local ok, parse_err = parser:parse_text(data, 'msgpack')
-        assert(ok, parse_err)
+        if not ok then
+          rspamd_logger.errx(rspamd_config, 'cannot parse training result for ANN %s:%s: %s (data size: %s)',
+            params.rule.prefix, params.set.name, parse_err, #data)
+          lua_redis.redis_make_request_taskless(params.ev_base,
+            rspamd_config,
+            params.rule.redis,
+            nil,
+            true,
+            gen_unlock_cb(params.rule, params.set, params.ann_key),
+            'HDEL',
+            { params.ann_key, 'lock' }
+          )
+          return
+        end
         local parsed = parser:get_object()
         local ann_data = rspamd_util.zstd_compress(parsed.ann_data)
         local pca_data = parsed.pca_data
@@ -1045,10 +1094,10 @@ local function spawn_train(params)
 
 
         -- Deserialise ANN from the child process
-        ann_trained = rspamd_kann.load(parsed.ann_data)
+        local loaded_ann = rspamd_kann.load(parsed.ann_data)
         local version = (params.set.ann.version or 0) + 1
         params.set.ann.version = version
-        params.set.ann.ann = ann_trained
+        params.set.ann.ann = loaded_ann
         params.set.ann.symbols = params.set.symbols
         params.set.ann.redis_key = new_ann_key(params.rule, params.set, version)
 
@@ -1306,6 +1355,7 @@ return {
   load_scripts = load_scripts,
   module_config = module_config,
   new_ann_key = new_ann_key,
+  pending_train_key = pending_train_key,
   providers_config_digest = providers_config_digest,
   register_provider = register_provider,
   plugin_ver = plugin_ver,
index 4148a753880076f46d98fb2b8f2371b97daa6913..2f3a6a391044ee7a0a31491ea5f9dce05834c77a 100644 (file)
@@ -6,7 +6,7 @@ local rspamd_logger = require "rspamd_logger"
 
 local parser = argparse()
     :name "rspamadm classifier_test"
-    :description "Learn bayes classifier and evaluate its performance"
+    :description "Learn classifier and evaluate its performance"
     :help_description_margin(32)
 
 parser:option "-H --ham"
@@ -15,8 +15,14 @@ parser:option "-H --ham"
 parser:option "-S --spam"
       :description("Spam directory")
       :argname("<dir>")
+parser:option "-C --classifier"
+      :description("Classifier type: bayes or llm_embeddings")
+      :argname("<type>")
+      :default('bayes')
 parser:flag "-n --no-learning"
       :description("Do not learn classifier")
+parser:flag "-T --train-only"
+      :description("Only train, do not evaluate (llm_embeddings only)")
 parser:option "--nconns"
       :description("Number of parallel connections")
       :argname("<N>")
@@ -35,19 +41,22 @@ parser:option "-r --rspamc"
       :description("Use specific rspamc path")
       :argname("<path>")
       :default('rspamc')
-parser:option "-c --cv-fraction"
+parser:option "-f --cv-fraction"
       :description("Use specific fraction for cross-validation")
       :argname("<fraction>")
       :convert(tonumber)
-      :default('0.7')
+      :default(0.7)
 parser:option "--spam-symbol"
-      :description("Use specific spam symbol (instead of BAYES_SPAM)")
+      :description("Use specific spam symbol (auto-detected from classifier type)")
       :argname("<symbol>")
-      :default('BAYES_SPAM')
 parser:option "--ham-symbol"
-      :description("Use specific ham symbol (instead of BAYES_HAM)")
+      :description("Use specific ham symbol (auto-detected from classifier type)")
       :argname("<symbol>")
-      :default('BAYES_HAM')
+parser:option "--train-wait"
+      :description("Seconds to wait after training for neural network (llm_embeddings only, should be > watch_interval)")
+      :argname("<sec>")
+      :convert(tonumber)
+      :default(90)
 
 local opts
 
@@ -78,15 +87,48 @@ local function list_to_file(list, fname)
   out:close()
 end
 
--- Function to train the classifier with given files
-local function train_classifier(files, command)
+-- Function to train the Bayes classifier with given files
+local function train_bayes(files, command)
   local fname = os.tmpname()
   list_to_file(files, fname)
   local rspamc_command = string.format("%s --connect %s -j --compact -n %s -t %.3f %s --files-list=%s",
       opts.rspamc, opts.connect, opts.nconns, opts.timeout, command, fname)
+  local handle = assert(io.popen(rspamc_command))
+  handle:read("*all")
+  handle:close()
+  os.remove(fname)
+end
+
+-- Function to train with ANN-Train header (for llm_embeddings/neural)
+-- Uses settings to enable only NEURAL_LEARN symbol, skipping full scan
+local function train_neural(files, learn_type)
+  local fname = os.tmpname()
+  list_to_file(files, fname)
+
+  -- Use ANN-Train header with settings to limit scan to NEURAL_LEARN only
+  local rspamc_command = string.format(
+    "%s --connect %s -j --compact -n %s -t %.3f " ..
+    "--settings '{\"symbols_enabled\":[\"NEURAL_LEARN\"]}' " ..
+    "--header 'ANN-Train=%s' --files-list=%s",
+    opts.rspamc, opts.connect, opts.nconns, opts.timeout,
+    learn_type, fname)
+
   local result = assert(io.popen(rspamc_command))
-  result = result:read("*all")
+  local output = result:read("*all")
+  result:close()
   os.remove(fname)
+
+  -- Count successful submissions
+  local count = 0
+  for line in output:gmatch("[^\n]+") do
+    local ucl_parser = ucl.parser()
+    local is_good, _ = ucl_parser:parse_string(line)
+    if is_good then
+      count = count + 1
+    end
+  end
+
+  return count
 end
 
 -- Function to classify files and return results
@@ -94,7 +136,7 @@ local function classify_files(files, known_spam_files, known_ham_files)
   local fname = os.tmpname()
   list_to_file(files, fname)
 
-  local settings_header = string.format('--header Settings=\"{symbols_enabled=[%s, %s]}\"',
+  local settings_header = string.format('--header Settings="{symbols_enabled=[%s, %s]}"',
       opts.spam_symbol, opts.ham_symbol)
   local rspamc_command = string.format("%s %s --connect %s --compact -n %s -t %.3f --files-list=%s",
       opts.rspamc,
@@ -109,26 +151,31 @@ local function classify_files(files, known_spam_files, known_ham_files)
     local is_good, err = ucl_parser:parse_string(line)
     if not is_good then
       rspamd_logger.errx("Parser error: %1", err)
-      os.remove(fname)
-      return nil
-    end
-    local obj = ucl_parser:get_object()
-    local file = obj.filename
-    local symbols = obj.symbols or {}
-
-    if symbols[opts.spam_symbol] then
-      table.insert(results, { result = "spam", file = file })
-      if known_ham_files[file] then
-        rspamd_logger.message("FP: %s is classified as spam but is known ham", file)
-      end
-    elseif symbols[opts.ham_symbol] then
-      if known_spam_files[file] then
-        rspamd_logger.message("FN: %s is classified as ham but is known spam", file)
+    else
+      local obj = ucl_parser:get_object()
+      local file = obj.filename
+      local symbols = obj.symbols or {}
+
+      if symbols[opts.spam_symbol] then
+        local score = symbols[opts.spam_symbol].score
+        table.insert(results, { result = "spam", file = file, score = score })
+        if known_ham_files[file] then
+          rspamd_logger.message("FP: %s is classified as spam but is known ham", file)
+        end
+      elseif symbols[opts.ham_symbol] then
+        local score = symbols[opts.ham_symbol].score
+        table.insert(results, { result = "ham", file = file, score = score })
+        if known_spam_files[file] then
+          rspamd_logger.message("FN: %s is classified as ham but is known spam", file)
+        end
+      else
+        -- No classification result
+        table.insert(results, { result = "unknown", file = file })
       end
-      table.insert(results, { result = "ham", file = file })
     end
   end
 
+  result:close()
   os.remove(fname)
 
   return results
@@ -137,7 +184,9 @@ end
 -- Function to evaluate classifier performance
 local function evaluate_results(results, spam_label, ham_label,
                                 known_spam_files, known_ham_files, total_cv_files, elapsed)
-  local true_positives, false_positives, true_negatives, false_negatives, total = 0, 0, 0, 0, 0
+  local true_positives, false_positives, true_negatives, false_negatives = 0, 0, 0, 0
+  local classified, unclassified = 0, 0
+
   for _, res in ipairs(results) do
     if res.result == spam_label then
       if known_spam_files[res.file] then
@@ -145,69 +194,174 @@ local function evaluate_results(results, spam_label, ham_label,
       elseif known_ham_files[res.file] then
         false_positives = false_positives + 1
       end
-      total = total + 1
+      classified = classified + 1
     elseif res.result == ham_label then
       if known_spam_files[res.file] then
         false_negatives = false_negatives + 1
       elseif known_ham_files[res.file] then
         true_negatives = true_negatives + 1
       end
-      total = total + 1
+      classified = classified + 1
+    else
+      unclassified = unclassified + 1
     end
   end
 
-  local accuracy = (true_positives + true_negatives) / total
-  local precision = true_positives / (true_positives + false_positives)
-  local recall = true_positives / (true_positives + false_negatives)
-  local f1_score = 2 * (precision * recall) / (precision + recall)
-
-  print(string.format("%-20s %-10s", "Metric", "Value"))
-  print(string.rep("-", 30))
+  print(string.format("\n%-20s %-10s", "Metric", "Value"))
+  print(string.rep("-", 35))
   print(string.format("%-20s %-10d", "True Positives", true_positives))
   print(string.format("%-20s %-10d", "False Positives", false_positives))
   print(string.format("%-20s %-10d", "True Negatives", true_negatives))
   print(string.format("%-20s %-10d", "False Negatives", false_negatives))
-  print(string.format("%-20s %-10.2f", "Accuracy", accuracy))
-  print(string.format("%-20s %-10.2f", "Precision", precision))
-  print(string.format("%-20s %-10.2f", "Recall", recall))
-  print(string.format("%-20s %-10.2f", "F1 Score", f1_score))
-  print(string.format("%-20s %-10.2f", "Classified (%)", total / total_cv_files * 100))
-  print(string.format("%-20s %-10.2f", "Elapsed time (seconds)", elapsed))
+  print(string.format("%-20s %-10d", "Unclassified", unclassified))
+
+  if classified > 0 then
+    local accuracy = (true_positives + true_negatives) / classified
+    local precision = true_positives > 0 and true_positives / (true_positives + false_positives) or 0
+    local recall = true_positives > 0 and true_positives / (true_positives + false_negatives) or 0
+    local f1_score = (precision + recall) > 0 and 2 * (precision * recall) / (precision + recall) or 0
+
+    print(string.format("%-20s %-10.4f", "Accuracy", accuracy))
+    print(string.format("%-20s %-10.4f", "Precision", precision))
+    print(string.format("%-20s %-10.4f", "Recall", recall))
+    print(string.format("%-20s %-10.4f", "F1 Score", f1_score))
+  end
+
+  print(string.format("%-20s %-10.2f%%", "Classified", classified / total_cv_files * 100))
+  print(string.format("%-20s %-10.2f", "Elapsed (sec)", elapsed))
 end
 
 local function handler(args)
   opts = parser:parse(args)
+
   local ham_directory = opts['ham']
   local spam_directory = opts['spam']
+  local classifier_type = opts['classifier']
+
+  if not ham_directory or not spam_directory then
+    print("Error: Both --ham and --spam directories are required")
+    os.exit(1)
+  end
+
+  -- Set default symbols based on classifier type
+  if not opts.spam_symbol then
+    if classifier_type == 'llm_embeddings' then
+      opts.spam_symbol = 'NEURAL_SPAM'
+    else
+      opts.spam_symbol = 'BAYES_SPAM'
+    end
+  end
+  if not opts.ham_symbol then
+    if classifier_type == 'llm_embeddings' then
+      opts.ham_symbol = 'NEURAL_HAM'
+    else
+      opts.ham_symbol = 'BAYES_HAM'
+    end
+  end
+
   -- Get all files
   local spam_files = get_files(spam_directory)
   local known_spam_files = lua_util.list_to_hash(spam_files)
   local ham_files = get_files(ham_directory)
   local known_ham_files = lua_util.list_to_hash(ham_files)
 
-  -- Split files into training and cross-validation sets
+  print(string.format("Classifier: %s", classifier_type))
+  print(string.format("Found %d spam files, %d ham files", #spam_files, #ham_files))
 
+  -- Split files into training and cross-validation sets
   local train_spam, cv_spam = split_table(spam_files, opts.cv_fraction)
   local train_ham, cv_ham = split_table(ham_files, opts.cv_fraction)
 
-  print(string.format("Spam: %d train files, %d cv files; ham: %d train files, %d cv files",
+  print(string.format("Split: %d/%d spam (train/test), %d/%d ham (train/test)",
       #train_spam, #cv_spam, #train_ham, #cv_ham))
+
+  -- Training phase
   if not opts.no_learning then
-    -- Train classifier
+    print("\n=== Training Phase ===")
+
     local t, train_spam_time, train_ham_time
-    print(string.format("Start learn spam, %d messages, %d connections", #train_spam, opts.nconns))
-    t = rspamd_util.get_time()
-    train_classifier(train_spam, "learn_spam")
-    train_spam_time = rspamd_util.get_time() - t
-    print(string.format("Start learn ham, %d messages, %d connections", #train_ham, opts.nconns))
-    t = rspamd_util.get_time()
-    train_classifier(train_ham, "learn_ham")
-    train_ham_time = rspamd_util.get_time() - t
-    print(string.format("Learning done: %d spam messages in %.2f seconds, %d ham messages in %.2f seconds",
-        #train_spam, train_spam_time, #train_ham, train_ham_time))
+
+    if classifier_type == 'llm_embeddings' then
+      -- Neural/LLM training using ANN-Train header
+      -- Interleave spam and ham submissions for balanced training
+      print(string.format("Training %d spam + %d ham messages (interleaved)...", #train_spam, #train_ham))
+      t = rspamd_util.get_time()
+
+      -- Create interleaved list of {file, type} pairs
+      local interleaved = {}
+      local spam_idx, ham_idx = 1, 1
+      while spam_idx <= #train_spam or ham_idx <= #train_ham do
+        if spam_idx <= #train_spam then
+          table.insert(interleaved, { file = train_spam[spam_idx], type = 'spam' })
+          spam_idx = spam_idx + 1
+        end
+        if ham_idx <= #train_ham then
+          table.insert(interleaved, { file = train_ham[ham_idx], type = 'ham' })
+          ham_idx = ham_idx + 1
+        end
+      end
+
+      -- Submit in batches, grouped by type for efficiency
+      local batch_size = math.max(1, math.floor(#interleaved / 10))
+      local spam_batch, ham_batch = {}, {}
+      local spam_trained, ham_trained = 0, 0
+
+      for i, item in ipairs(interleaved) do
+        if item.type == 'spam' then
+          table.insert(spam_batch, item.file)
+        else
+          table.insert(ham_batch, item.file)
+        end
+
+        -- Submit batches periodically
+        if i % batch_size == 0 or i == #interleaved then
+          if #spam_batch > 0 then
+            spam_trained = spam_trained + train_neural(spam_batch, "spam")
+            spam_batch = {}
+          end
+          if #ham_batch > 0 then
+            ham_trained = ham_trained + train_neural(ham_batch, "ham")
+            ham_batch = {}
+          end
+        end
+      end
+
+      train_spam_time = rspamd_util.get_time() - t
+      train_ham_time = 0 -- Combined time
+      print(string.format("  Submitted %d spam + %d ham samples in %.2f seconds",
+        spam_trained, ham_trained, train_spam_time))
+
+      -- Wait for neural network to train using ev_base sleep
+      print(string.format("\nWaiting %d seconds for neural network training...", opts.train_wait))
+      rspamadm_ev_base:sleep(opts.train_wait)
+      print("Training wait complete.")
+    else
+      -- Bayes training using learn_spam/learn_ham
+      print(string.format("Start learn spam, %d messages, %d connections", #train_spam, opts.nconns))
+      t = rspamd_util.get_time()
+      train_bayes(train_spam, "learn_spam")
+      train_spam_time = rspamd_util.get_time() - t
+
+      print(string.format("Start learn ham, %d messages, %d connections", #train_ham, opts.nconns))
+      t = rspamd_util.get_time()
+      train_bayes(train_ham, "learn_ham")
+      train_ham_time = rspamd_util.get_time() - t
+
+      print(string.format("Learning done: %d spam in %.2f sec, %d ham in %.2f sec",
+          #train_spam, train_spam_time, #train_ham, train_ham_time))
+    end
+  else
+    print("\nSkipping training phase (--no-learning)")
+  end
+
+  if opts.train_only then
+    print("\nTraining only mode - skipping evaluation")
+    return
   end
 
-  -- Classify cross-validation files
+  -- Cross-validation phase
+  print("\n=== Evaluation Phase ===")
+
   local cv_files = {}
   for _, file in ipairs(cv_spam) do
     table.insert(cv_files, file)
@@ -219,7 +373,8 @@ local function handler(args)
   -- Shuffle cross-validation files
   cv_files = split_table(cv_files, 1)
 
-  print(string.format("Start cross validation, %d messages, %d connections", #cv_files, opts.nconns))
+  print(string.format("Classifying %d test messages...", #cv_files))
+
   -- Get classification results
   local t = rspamd_util.get_time()
   local results = classify_files(cv_files, known_spam_files, known_ham_files)
@@ -231,7 +386,6 @@ local function handler(args)
       known_ham_files,
       #cv_files,
       elapsed)
-
 end
 
 return {
index 628e4a62d78aa2346ea9c1ad1e689d5636e92f92..4c8c931d0f066b5679c3ca5e5b783642b4f1c954 100644 (file)
@@ -345,7 +345,9 @@ local function handle_learn_message(task, conn)
     lua_util.debugm(N, task, 'controller.neural: vector size=%s head=[%s]', vlen, vhead)
 
     local compressed = rspamd_util.zstd_compress(table.concat(vec, ';'))
-    local target_key = string.format('%s_%s_set', redis_base, learn_type)
+    -- Use pending key for manual training (picked up by training loop)
+    local pending_key = neural_common.pending_train_key(rule, set)
+    local target_key = string.format('%s_%s_set', pending_key, learn_type)
 
     local function learn_vec_cb(redis_err)
       if redis_err then
@@ -385,11 +387,94 @@ local function handle_train(task, conn, req_params)
     conn:send_error(400, 'unknown rule')
     return
   end
-  -- Trigger check_anns to evaluate training conditions
-  rspamd_config:add_periodic(task:get_ev_base(), 0.0, function()
-    return 0.0
-  end)
-  conn:send_ucl({ success = true, message = 'training scheduled check' })
+
+  -- Get the set for this rule
+  local set = neural_common.get_rule_settings(task, rule)
+  if not set then
+    -- Try to find any available set
+    for sid, s in pairs(rule.settings or {}) do
+      if type(s) == 'table' then
+        set = s
+        set.name = set.name or sid
+        break
+      end
+    end
+  end
+
+  if not set then
+    conn:send_error(400, 'no settings found for rule')
+    return
+  end
+
+  -- Check pending vectors count
+  local pending_key = neural_common.pending_train_key(rule, set)
+  local ev_base = task:get_ev_base()
+
+  local function check_and_train(spam_count, ham_count)
+    if spam_count > 0 and ham_count > 0 then
+      -- We have vectors in pending, find or create a profile and train
+      local ann_key
+      if set.ann and set.ann.redis_key then
+        ann_key = set.ann.redis_key
+      else
+        -- Create a new profile
+        ann_key = neural_common.new_ann_key(rule, set, 0)
+      end
+
+      rspamd_logger.infox(task, 'manual train trigger for %s:%s with %s spam, %s ham vectors',
+        rule.prefix, set.name, spam_count, ham_count)
+
+      -- The training will be picked up by the next check_anns cycle
+      -- For immediate training, we'd need access to the worker object
+      conn:send_ucl({
+        success = true,
+        message = 'training vectors available',
+        spam_vectors = spam_count,
+        ham_vectors = ham_count,
+        pending_key = pending_key
+      })
+    else
+      conn:send_ucl({
+        success = false,
+        message = 'not enough vectors for training',
+        spam_vectors = spam_count,
+        ham_vectors = ham_count
+      })
+    end
+  end
+
+  -- Count pending vectors
+  local spam_count = 0
+  local function count_ham_cb(err, data)
+    local ham_count = 0
+    if not err and (type(data) == 'number' or type(data) == 'string') then
+      ham_count = tonumber(data) or 0
+    end
+    check_and_train(spam_count, ham_count)
+  end
+
+  local function count_spam_cb(err, data)
+    if not err and (type(data) == 'number' or type(data) == 'string') then
+      spam_count = tonumber(data) or 0
+    end
+    lua_redis.redis_make_request(task,
+      rule.redis,
+      nil,
+      false,
+      count_ham_cb,
+      'SCARD',
+      { pending_key .. '_ham_set' }
+    )
+  end
+
+  lua_redis.redis_make_request(task,
+    rule.redis,
+    nil,
+    false,
+    count_spam_cb,
+    'SCARD',
+    { pending_key .. '_spam_set' }
+  )
 end
 
 return {
index 12dfee02ac7d8d5fdbab8d2017c21300b001e5ca..a5dcf031b2b6131fdc21a083cc7fb907f4348135 100644 (file)
@@ -24,6 +24,7 @@
 #include "libutil/str_util.h"
 #include "libserver/html/html.h"
 #include "libserver/hyperscan_tools.h"
+#include "libserver/async_session.h"
 
 #include "lua_parsers.h"
 
@@ -4474,9 +4475,17 @@ struct rspamd_ev_base_sleep_cbdata {
        lua_State *L;
        int cbref;
        struct thread_entry *thread;
+       struct rspamd_async_session *session;
        ev_timer ev;
 };
 
+/* Dummy finalizer for sleep session events - the timer handles cleanup */
+static void
+lua_ev_base_sleep_session_fin(gpointer ud)
+{
+       /* Nothing to do - timer callback handles everything */
+}
+
 static void
 lua_ev_base_sleep_cb(struct ev_loop *loop, struct ev_timer *t, int events)
 {
@@ -4485,6 +4494,12 @@ lua_ev_base_sleep_cb(struct ev_loop *loop, struct ev_timer *t, int events)
 
        ev_timer_stop(loop, t);
 
+       /* Remove session event if we registered one */
+       if (cbdata->session) {
+               rspamd_session_remove_event(cbdata->session,
+                                                                       lua_ev_base_sleep_session_fin, cbdata);
+       }
+
        if (cbdata->cbref != -1) {
                /* Async mode: call the callback */
                lua_State *L = cbdata->L;
@@ -4516,6 +4531,25 @@ lua_ev_base_sleep_cb(struct ev_loop *loop, struct ev_timer *t, int events)
  * @param {number} time timeout in seconds
  * @param {function} callback optional callback for async mode
  */
+/* Helper to get rspamadm_session from Lua globals if available */
+static struct rspamd_async_session *
+lua_get_rspamadm_session(lua_State *L)
+{
+       struct rspamd_async_session *session = NULL;
+
+       lua_getglobal(L, "rspamadm_session");
+
+       if (lua_type(L, -1) == LUA_TUSERDATA) {
+               void *ud = rspamd_lua_check_udata_maybe(L, -1, rspamd_session_classname);
+               if (ud) {
+                       session = *((struct rspamd_async_session **) ud);
+               }
+       }
+
+       lua_pop(L, 1);
+       return session;
+}
+
 static int
 lua_ev_base_sleep(lua_State *L)
 {
@@ -4534,12 +4568,23 @@ lua_ev_base_sleep(lua_State *L)
        cbdata->ev.data = cbdata;
        cbdata->cbref = -1;
        cbdata->thread = NULL;
+       cbdata->session = NULL;
+
+       /* Try to get rspamadm_session for session event tracking */
+       struct rspamd_async_session *session = lua_get_rspamadm_session(L);
 
        if (lua_isfunction(L, 3)) {
                /* Async mode with callback */
                lua_pushvalue(L, 3);
                cbdata->cbref = luaL_ref(L, LUA_REGISTRYINDEX);
 
+               /* Register session event if available so wait_session_events waits for us */
+               if (session) {
+                       cbdata->session = session;
+                       rspamd_session_add_event(session,
+                                                                        lua_ev_base_sleep_session_fin, cbdata, "lua sleep");
+               }
+
                ev_timer_init(&cbdata->ev, lua_ev_base_sleep_cb, timeout, 0.0);
                ev_timer_start(ev_base, &cbdata->ev);
 
@@ -4556,6 +4601,13 @@ lua_ev_base_sleep(lua_State *L)
                        if (cfg && cfg->lua_thread_pool) {
                                cbdata->thread = lua_thread_pool_get_running_entry(cfg->lua_thread_pool);
 
+                               /* Register session event if available so wait_session_events waits for us */
+                               if (session) {
+                                       cbdata->session = session;
+                                       rspamd_session_add_event(session,
+                                                                                        lua_ev_base_sleep_session_fin, cbdata, "lua sleep");
+                               }
+
                                ev_timer_init(&cbdata->ev, lua_ev_base_sleep_cb, timeout, 0.0);
                                ev_timer_start(ev_base, &cbdata->ev);
 
index 7054ab92328d554bfd9d1a0bf98dcb9885622f44..bf703ec89bd47237155da81b03109fe462554c11 100644 (file)
@@ -116,10 +116,24 @@ local function ann_scores_filter(task)
 
     if ann then
       local function after_features(vec, meta)
-        if profile.providers_digest and meta and meta.digest and profile.providers_digest ~= meta.digest then
-          lua_util.debugm(N, task, 'providers digest mismatch for %s:%s, skip ANN apply',
-            rule.prefix, set.name)
-          vec = nil
+        -- For providers-based ANNs, require matching digest
+        -- For symbols-based ANNs (no providers), skip this check
+        local has_providers = rule.providers and #rule.providers > 0
+        if has_providers then
+          local stored_digest = profile.providers_digest
+          local current_digest = meta and meta.digest
+          if not stored_digest then
+            -- Old ANN was trained without providers - needs retraining with current config
+            lua_util.debugm(N, task,
+              'ANN %s:%s was trained without providers, skipping (retrain with current config)',
+              rule.prefix, set.name)
+            vec = nil
+          elseif stored_digest ~= current_digest then
+            rspamd_logger.warnx(task,
+              'providers config changed for %s:%s (stored=%s, current=%s), ANN needs retraining',
+              rule.prefix, set.name, stored_digest, current_digest or 'none')
+            vec = nil
+          end
         end
 
         local score
@@ -350,7 +364,13 @@ local function ann_push_task_result(rule, task, verdict, score, set)
             end
 
             local str = rspamd_util.zstd_compress(table.concat(vec, ';'))
-            local target_key = set.ann.redis_key .. '_' .. learn_type .. '_set'
+            -- For manual training, use stable pending key to avoid version mismatch
+            local target_key
+            if manual_train then
+              target_key = neural_common.pending_train_key(rule, set) .. '_' .. learn_type .. '_set'
+            else
+              target_key = set.ann.redis_key .. '_' .. learn_type .. '_set'
+            end
 
             local function learn_vec_cb(redis_err)
               if redis_err then
@@ -405,13 +425,14 @@ local function ann_push_task_result(rule, task, verdict, score, set)
     end
 
     -- Check if we can learn
-    if set.can_store_vectors then
+    -- For manual training, bypass can_store_vectors check (it may not be set yet)
+    if set.can_store_vectors or manual_train then
       if not set.ann then
         -- Need to create or load a profile corresponding to the current configuration
         set.ann = new_ann_profile(task, rule, set, 0)
         lua_util.debugm(N, task,
-          'requested new profile for %s, set.ann is missing',
-          set.name)
+          'requested new profile for %s, set.ann is missing (manual_train=%s)',
+          set.name, manual_train)
       end
 
       lua_redis.exec_redis_script(neural_common.redis_script_id.vectors_len,
@@ -444,12 +465,14 @@ end
 
 -- This function does the following:
 -- * Tries to lock ANN
--- * Loads spam and ham vectors
+-- * Loads spam and ham vectors (from versioned key AND pending key)
 -- * Spawn learning process
 local function do_train_ann(worker, ev_base, rule, set, ann_key)
   local spam_elts = {}
   local ham_elts = {}
-  lua_util.debugm(N, rspamd_config, 'do_train_ann: start for %s:%s key=%s', rule.prefix, set.name, ann_key)
+  local pending_key = neural_common.pending_train_key(rule, set)
+  lua_util.debugm(N, rspamd_config, 'do_train_ann: start for %s:%s key=%s pending=%s',
+    rule.prefix, set.name, ann_key, pending_key)
 
   local function redis_ham_cb(err, data)
     if err or type(data) ~= 'table' then
@@ -475,7 +498,8 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key)
         set = set,
         ann_key = ann_key,
         ham_vec = ham_elts,
-        spam_vec = spam_elts
+        spam_vec = spam_elts,
+        pending_key = pending_key
       })
     end
   end
@@ -498,15 +522,15 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key)
     else
       -- Decompress and convert to numbers each training vector
       spam_elts = process_training_vectors(data)
-      -- Now get ham vectors...
+      -- Now get ham vectors from both versioned and pending keys
       lua_redis.redis_make_request_taskless(ev_base,
         rspamd_config,
         rule.redis,
         nil,
         false,        -- is write
         redis_ham_cb, --callback
-        'SMEMBERS',   -- command
-        { ann_key .. '_ham_set' }
+        'SUNION',     -- command (union of sets)
+        { ann_key .. '_ham_set', pending_key .. '_ham_set' }
       )
     end
   end
@@ -517,18 +541,19 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key)
         ann_key, err)
     elseif type(data) == 'number' and data == 1 then
       -- ANN is locked, so we can extract SPAM and HAM vectors and spawn learning
+      -- Fetch from both versioned key and pending key using SUNION
       lua_redis.redis_make_request_taskless(ev_base,
         rspamd_config,
         rule.redis,
         nil,
         false,         -- is write
         redis_spam_cb, --callback
-        'SMEMBERS',    -- command
-        { ann_key .. '_spam_set' }
+        'SUNION',      -- command (union of sets)
+        { ann_key .. '_spam_set', pending_key .. '_spam_set' }
       )
 
-      rspamd_logger.infox(rspamd_config, 'lock ANN %s:%s (key name %s) for learning',
-        rule.prefix, set.name, ann_key)
+      rspamd_logger.infox(rspamd_config, 'lock ANN %s:%s (key name %s, pending %s) for learning',
+        rule.prefix, set.name, ann_key, pending_key)
     else
       local lock_tm = tonumber(data[1])
       rspamd_logger.infox(rspamd_config, 'do not learn ANN %s:%s (key name %s), ' ..
@@ -810,98 +835,132 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
   if sel_elt then
     -- We have our ANN and that's train vectors, check if we can learn
     local ann_key = sel_elt.redis_key
+    local pending_key = neural_common.pending_train_key(rule, set)
 
-    lua_util.debugm(N, rspamd_config, "check if ANN %s needs to be trained",
-      ann_key)
-
-    -- Create continuation closure
-    local redis_len_cb_gen = function(cont_cb, what, is_final)
-      return function(err, data)
-        if err then
-          rspamd_logger.errx(rspamd_config,
-            'cannot get ANN %s trains %s from redis: %s', what, ann_key, err)
-        elseif data and type(data) == 'number' or type(data) == 'string' then
-          local ntrains = tonumber(data) or 0
-          lens[what] = ntrains
-          if is_final then
-            -- Ensure that we have the following:
-            -- one class has reached max_trains
-            -- other class(es) are at least as full as classes_bias
-            -- e.g. if classes_bias = 0.25 and we have 10 max_trains then
-            -- one class must have 10 or more trains whilst another should have
-            -- at least (10 * (1 - 0.25)) = 8 trains
-
-            local max_len = math.max(lua_util.unpack(lua_util.values(lens)))
-            local min_len = math.min(lua_util.unpack(lua_util.values(lens)))
-
-            if rule.train.learn_type == 'balanced' then
-              local len_bias_check_pred = function(_, l)
-                return l >= rule.train.max_trains * (1.0 - rule.train.classes_bias)
-              end
-              if max_len >= rule.train.max_trains and fun.all(len_bias_check_pred, lens) then
-                lua_util.debugm(N, rspamd_config,
-                  'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
-                  ann_key, lens, rule.train.max_trains, what)
-                cont_cb()
-              else
-                lua_util.debugm(N, rspamd_config,
-                  'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
-                  ann_key, what, lens, rule.train.max_trains)
-              end
-            else
-              -- Probabilistic mode, just ensure that at least one vector is okay
-              if min_len > 0 and max_len >= rule.train.max_trains then
-                lua_util.debugm(N, rspamd_config,
-                  'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
-                  ann_key, lens, rule.train.max_trains, what)
-                cont_cb()
-              else
-                lua_util.debugm(N, rspamd_config,
-                  'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
-                  ann_key, what, lens, rule.train.max_trains)
-              end
-            end
-          else
-            lua_util.debugm(N, rspamd_config,
-              'checked %s vectors in ANN %s: %s vectors; %s required, need to check other class vectors',
-              what, ann_key, ntrains, rule.train.max_trains)
-            cont_cb()
-          end
-        end
-      end
-    end
+    lua_util.debugm(N, rspamd_config, "check if ANN %s (pending %s) needs to be trained",
+      ann_key, pending_key)
 
     local function initiate_train()
       rspamd_logger.infox(rspamd_config,
-        'need to learn ANN %s after %s required learn vectors',
-        ann_key, lens)
+        'need to learn ANN %s (pending %s) after %s required learn vectors',
+        ann_key, pending_key, lens)
       lua_util.debugm(N, rspamd_config, 'maybe_train_existing_ann: initiating train for key=%s spam=%s ham=%s', ann_key,
         lens.spam or -1, lens.ham or -1)
       do_train_ann(worker, ev_base, rule, set, ann_key)
     end
 
-    -- Spam vector is OK, check ham vector length
+    -- Final check after all vectors are counted
+    local function maybe_initiate_train()
+      local max_len = math.max(lua_util.unpack(lua_util.values(lens)))
+      local min_len = math.min(lua_util.unpack(lua_util.values(lens)))
+
+      lua_util.debugm(N, rspamd_config,
+        'final vector count for ANN %s: spam=%s ham=%s (min=%s max=%s required=%s)',
+        ann_key, lens.spam, lens.ham, min_len, max_len, rule.train.max_trains)
+
+      if rule.train.learn_type == 'balanced' then
+        local len_bias_check_pred = function(_, l)
+          return l >= rule.train.max_trains * (1.0 - rule.train.classes_bias)
+        end
+        if max_len >= rule.train.max_trains and fun.all(len_bias_check_pred, lens) then
+          initiate_train()
+        else
+          lua_util.debugm(N, rspamd_config,
+            'cannot learn ANN %s: balanced mode requires more vectors (has %s)',
+            ann_key, lens)
+        end
+      else
+        -- Probabilistic mode
+        if min_len > 0 and max_len >= rule.train.max_trains then
+          initiate_train()
+        else
+          lua_util.debugm(N, rspamd_config,
+            'cannot learn ANN %s: need min_len > 0 and max_len >= %s (has %s)',
+            ann_key, rule.train.max_trains, lens)
+        end
+      end
+    end
+
+    -- Callback that adds count from pending key and continues
+    local function add_pending_cb(cont_cb, what)
+      return function(err, data)
+        if not err and (type(data) == 'number' or type(data) == 'string') then
+          local pending_count = tonumber(data) or 0
+          lens[what] = (lens[what] or 0) + pending_count
+          lua_util.debugm(N, rspamd_config, 'added %s pending %s vectors, total now %s',
+            pending_count, what, lens[what])
+        end
+        cont_cb()
+      end
+    end
+
+    -- Simple callback that just adds versioned count and continues
+    local function add_versioned_cb(cont_cb, what)
+      return function(err, data)
+        if not err and (type(data) == 'number' or type(data) == 'string') then
+          local count = tonumber(data) or 0
+          lens[what] = (lens[what] or 0) + count
+          lua_util.debugm(N, rspamd_config, 'added %s versioned %s vectors, total now %s',
+            count, what, lens[what])
+        end
+        cont_cb()
+      end
+    end
+
+    -- Check pending ham, then make final decision
+    local function check_pending_ham()
+      lua_redis.redis_make_request_taskless(ev_base,
+        rspamd_config,
+        rule.redis,
+        nil,
+        false,
+        add_pending_cb(maybe_initiate_train, 'ham'),
+        'SCARD',
+        { pending_key .. '_ham_set' }
+      )
+    end
+
+    -- Check versioned ham, then check pending ham
     local function check_ham_len()
       lua_redis.redis_make_request_taskless(ev_base,
         rspamd_config,
         rule.redis,
         nil,
-        false,                                         -- is write
-        redis_len_cb_gen(initiate_train, 'ham', true), --callback
-        'SCARD',                                       -- command
+        false,
+        add_versioned_cb(check_pending_ham, 'ham'),
+        'SCARD',
         { ann_key .. '_ham_set' }
       )
     end
 
-    lua_redis.redis_make_request_taskless(ev_base,
-      rspamd_config,
-      rule.redis,
-      nil,
-      false,                                          -- is write
-      redis_len_cb_gen(check_ham_len, 'spam', false), --callback
-      'SCARD',                                        -- command
-      { ann_key .. '_spam_set' }
-    )
+    -- Check pending spam, then check ham
+    local function check_pending_spam()
+      lua_redis.redis_make_request_taskless(ev_base,
+        rspamd_config,
+        rule.redis,
+        nil,
+        false,
+        add_pending_cb(check_ham_len, 'spam'),
+        'SCARD',
+        { pending_key .. '_spam_set' }
+      )
+    end
+
+    -- Check versioned spam, then pending spam
+    local function check_spam_len()
+      lua_redis.redis_make_request_taskless(ev_base,
+        rspamd_config,
+        rule.redis,
+        nil,
+        false,
+        add_versioned_cb(check_pending_spam, 'spam'),
+        'SCARD',
+        { ann_key .. '_spam_set' }
+      )
+    end
+
+    -- Start the chain
+    check_spam_len()
   end
 end
 
@@ -999,7 +1058,9 @@ local function ann_push_vector(task)
     lua_util.debugm(N, task, 'do not push data for skipped task')
     return
   end
-  if not settings.allow_local and lua_util.is_rspamc_or_controller(task) then
+  -- Allow manual training via ANN-Train header regardless of allow_local
+  local manual_train_header = get_ann_train_header(task)
+  if not settings.allow_local and not manual_train_header and lua_util.is_rspamc_or_controller(task) then
     lua_util.debugm(N, task, 'do not push data for manual scan')
     return
   end