From: Vsevolod Stakhov Date: Mon, 19 Jan 2026 18:08:58 +0000 (+0000) Subject: [Feature] Add pending training keys and fix neural network training issues X-Git-Tag: 4.0.0~179^2~18 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=843411c1932baf186bf0e57cbfcef1a0a3072c15;p=thirdparty%2Frspamd.git [Feature] Add pending training keys and fix neural network training issues - Add pending_train_key() for version-independent training vector storage - Fix variable shadowing bug where ann_trained callback was overwritten - Add concurrent training prevention via learning_spawned check - Replace assert with proper error handling for msgpack parsing - Clean up pending keys after successful training - Update controller endpoint to use pending keys for manual training - Fix ev_base:sleep() to register with session events properly - Update classifier_test.lua to support llm_embeddings classifier testing Co-Authored-By: Claude --- diff --git a/lualib/plugins/neural.lua b/lualib/plugins/neural.lua index 919df7f32c..7f16618f19 100644 --- a/lualib/plugins/neural.lua +++ b/lualib/plugins/neural.lua @@ -587,6 +587,13 @@ local function redis_ann_prefix(rule, settings_name) settings.prefix, plugin_ver, rule.prefix, n, settings_name) end +-- Returns a stable key for pending training vectors (version-independent) +-- Used for batch/manual training to avoid version mismatch issues +local function pending_train_key(rule, set) + return string.format('%s_%s_%s_pending', + settings.prefix, rule.prefix, set.name) +end + -- Compute a stable digest for providers configuration local function providers_config_digest(providers_cfg, rule) if not providers_cfg then return nil end @@ -816,6 +823,13 @@ end -- This function receives training vectors, checks them, spawn learning and saves ANN in Redis local function spawn_train(params) + -- Prevent concurrent training + if params.set.learning_spawned then + lua_util.debugm(N, rspamd_config, 'spawn_train: training already in progress for %s:%s, skipping', + params.rule.prefix, params.set.name) + return + end + -- Check training data sanity -- Now we need to join inputs and create the appropriate test vectors local n @@ -1006,6 +1020,28 @@ local function spawn_train(params) else rspamd_logger.infox(rspamd_config, 'saved ANN %s:%s to redis: %s', params.rule.prefix, params.set.name, params.set.ann.redis_key) + + -- Clean up pending training keys if they were used + if params.pending_key then + local function cleanup_cb(cleanup_err) + if cleanup_err then + lua_util.debugm(N, rspamd_config, 'failed to cleanup pending keys: %s', cleanup_err) + else + lua_util.debugm(N, rspamd_config, 'cleaned up pending training keys for %s', + params.pending_key) + end + end + -- Delete both spam and ham pending sets + lua_redis.redis_make_request_taskless(params.ev_base, + rspamd_config, + params.rule.redis, + nil, + true, -- is write + cleanup_cb, + 'DEL', + { params.pending_key .. '_spam_set', params.pending_key .. '_ham_set' } + ) + end end end @@ -1026,7 +1062,20 @@ local function spawn_train(params) else local parser = ucl.parser() local ok, parse_err = parser:parse_text(data, 'msgpack') - assert(ok, parse_err) + if not ok then + rspamd_logger.errx(rspamd_config, 'cannot parse training result for ANN %s:%s: %s (data size: %s)', + params.rule.prefix, params.set.name, parse_err, #data) + lua_redis.redis_make_request_taskless(params.ev_base, + rspamd_config, + params.rule.redis, + nil, + true, + gen_unlock_cb(params.rule, params.set, params.ann_key), + 'HDEL', + { params.ann_key, 'lock' } + ) + return + end local parsed = parser:get_object() local ann_data = rspamd_util.zstd_compress(parsed.ann_data) local pca_data = parsed.pca_data @@ -1045,10 +1094,10 @@ local function spawn_train(params) -- Deserialise ANN from the child process - ann_trained = rspamd_kann.load(parsed.ann_data) + local loaded_ann = rspamd_kann.load(parsed.ann_data) local version = (params.set.ann.version or 0) + 1 params.set.ann.version = version - params.set.ann.ann = ann_trained + params.set.ann.ann = loaded_ann params.set.ann.symbols = params.set.symbols params.set.ann.redis_key = new_ann_key(params.rule, params.set, version) @@ -1306,6 +1355,7 @@ return { load_scripts = load_scripts, module_config = module_config, new_ann_key = new_ann_key, + pending_train_key = pending_train_key, providers_config_digest = providers_config_digest, register_provider = register_provider, plugin_ver = plugin_ver, diff --git a/lualib/rspamadm/classifier_test.lua b/lualib/rspamadm/classifier_test.lua index 4148a75388..2f3a6a3910 100644 --- a/lualib/rspamadm/classifier_test.lua +++ b/lualib/rspamadm/classifier_test.lua @@ -6,7 +6,7 @@ local rspamd_logger = require "rspamd_logger" local parser = argparse() :name "rspamadm classifier_test" - :description "Learn bayes classifier and evaluate its performance" + :description "Learn classifier and evaluate its performance" :help_description_margin(32) parser:option "-H --ham" @@ -15,8 +15,14 @@ parser:option "-H --ham" parser:option "-S --spam" :description("Spam directory") :argname("") +parser:option "-C --classifier" + :description("Classifier type: bayes or llm_embeddings") + :argname("") + :default('bayes') parser:flag "-n --no-learning" :description("Do not learn classifier") +parser:flag "-T --train-only" + :description("Only train, do not evaluate (llm_embeddings only)") parser:option "--nconns" :description("Number of parallel connections") :argname("") @@ -35,19 +41,22 @@ parser:option "-r --rspamc" :description("Use specific rspamc path") :argname("") :default('rspamc') -parser:option "-c --cv-fraction" +parser:option "-f --cv-fraction" :description("Use specific fraction for cross-validation") :argname("") :convert(tonumber) - :default('0.7') + :default(0.7) parser:option "--spam-symbol" - :description("Use specific spam symbol (instead of BAYES_SPAM)") + :description("Use specific spam symbol (auto-detected from classifier type)") :argname("") - :default('BAYES_SPAM') parser:option "--ham-symbol" - :description("Use specific ham symbol (instead of BAYES_HAM)") + :description("Use specific ham symbol (auto-detected from classifier type)") :argname("") - :default('BAYES_HAM') +parser:option "--train-wait" + :description("Seconds to wait after training for neural network (llm_embeddings only, should be > watch_interval)") + :argname("") + :convert(tonumber) + :default(90) local opts @@ -78,15 +87,48 @@ local function list_to_file(list, fname) out:close() end --- Function to train the classifier with given files -local function train_classifier(files, command) +-- Function to train the Bayes classifier with given files +local function train_bayes(files, command) local fname = os.tmpname() list_to_file(files, fname) local rspamc_command = string.format("%s --connect %s -j --compact -n %s -t %.3f %s --files-list=%s", opts.rspamc, opts.connect, opts.nconns, opts.timeout, command, fname) + local handle = assert(io.popen(rspamc_command)) + handle:read("*all") + handle:close() + os.remove(fname) +end + +-- Function to train with ANN-Train header (for llm_embeddings/neural) +-- Uses settings to enable only NEURAL_LEARN symbol, skipping full scan +local function train_neural(files, learn_type) + local fname = os.tmpname() + list_to_file(files, fname) + + -- Use ANN-Train header with settings to limit scan to NEURAL_LEARN only + local rspamc_command = string.format( + "%s --connect %s -j --compact -n %s -t %.3f " .. + "--settings '{\"symbols_enabled\":[\"NEURAL_LEARN\"]}' " .. + "--header 'ANN-Train=%s' --files-list=%s", + opts.rspamc, opts.connect, opts.nconns, opts.timeout, + learn_type, fname) + local result = assert(io.popen(rspamc_command)) - result = result:read("*all") + local output = result:read("*all") + result:close() os.remove(fname) + + -- Count successful submissions + local count = 0 + for line in output:gmatch("[^\n]+") do + local ucl_parser = ucl.parser() + local is_good, _ = ucl_parser:parse_string(line) + if is_good then + count = count + 1 + end + end + + return count end -- Function to classify files and return results @@ -94,7 +136,7 @@ local function classify_files(files, known_spam_files, known_ham_files) local fname = os.tmpname() list_to_file(files, fname) - local settings_header = string.format('--header Settings=\"{symbols_enabled=[%s, %s]}\"', + local settings_header = string.format('--header Settings="{symbols_enabled=[%s, %s]}"', opts.spam_symbol, opts.ham_symbol) local rspamc_command = string.format("%s %s --connect %s --compact -n %s -t %.3f --files-list=%s", opts.rspamc, @@ -109,26 +151,31 @@ local function classify_files(files, known_spam_files, known_ham_files) local is_good, err = ucl_parser:parse_string(line) if not is_good then rspamd_logger.errx("Parser error: %1", err) - os.remove(fname) - return nil - end - local obj = ucl_parser:get_object() - local file = obj.filename - local symbols = obj.symbols or {} - - if symbols[opts.spam_symbol] then - table.insert(results, { result = "spam", file = file }) - if known_ham_files[file] then - rspamd_logger.message("FP: %s is classified as spam but is known ham", file) - end - elseif symbols[opts.ham_symbol] then - if known_spam_files[file] then - rspamd_logger.message("FN: %s is classified as ham but is known spam", file) + else + local obj = ucl_parser:get_object() + local file = obj.filename + local symbols = obj.symbols or {} + + if symbols[opts.spam_symbol] then + local score = symbols[opts.spam_symbol].score + table.insert(results, { result = "spam", file = file, score = score }) + if known_ham_files[file] then + rspamd_logger.message("FP: %s is classified as spam but is known ham", file) + end + elseif symbols[opts.ham_symbol] then + local score = symbols[opts.ham_symbol].score + table.insert(results, { result = "ham", file = file, score = score }) + if known_spam_files[file] then + rspamd_logger.message("FN: %s is classified as ham but is known spam", file) + end + else + -- No classification result + table.insert(results, { result = "unknown", file = file }) end - table.insert(results, { result = "ham", file = file }) end end + result:close() os.remove(fname) return results @@ -137,7 +184,9 @@ end -- Function to evaluate classifier performance local function evaluate_results(results, spam_label, ham_label, known_spam_files, known_ham_files, total_cv_files, elapsed) - local true_positives, false_positives, true_negatives, false_negatives, total = 0, 0, 0, 0, 0 + local true_positives, false_positives, true_negatives, false_negatives = 0, 0, 0, 0 + local classified, unclassified = 0, 0 + for _, res in ipairs(results) do if res.result == spam_label then if known_spam_files[res.file] then @@ -145,69 +194,174 @@ local function evaluate_results(results, spam_label, ham_label, elseif known_ham_files[res.file] then false_positives = false_positives + 1 end - total = total + 1 + classified = classified + 1 elseif res.result == ham_label then if known_spam_files[res.file] then false_negatives = false_negatives + 1 elseif known_ham_files[res.file] then true_negatives = true_negatives + 1 end - total = total + 1 + classified = classified + 1 + else + unclassified = unclassified + 1 end end - local accuracy = (true_positives + true_negatives) / total - local precision = true_positives / (true_positives + false_positives) - local recall = true_positives / (true_positives + false_negatives) - local f1_score = 2 * (precision * recall) / (precision + recall) - - print(string.format("%-20s %-10s", "Metric", "Value")) - print(string.rep("-", 30)) + print(string.format("\n%-20s %-10s", "Metric", "Value")) + print(string.rep("-", 35)) print(string.format("%-20s %-10d", "True Positives", true_positives)) print(string.format("%-20s %-10d", "False Positives", false_positives)) print(string.format("%-20s %-10d", "True Negatives", true_negatives)) print(string.format("%-20s %-10d", "False Negatives", false_negatives)) - print(string.format("%-20s %-10.2f", "Accuracy", accuracy)) - print(string.format("%-20s %-10.2f", "Precision", precision)) - print(string.format("%-20s %-10.2f", "Recall", recall)) - print(string.format("%-20s %-10.2f", "F1 Score", f1_score)) - print(string.format("%-20s %-10.2f", "Classified (%)", total / total_cv_files * 100)) - print(string.format("%-20s %-10.2f", "Elapsed time (seconds)", elapsed)) + print(string.format("%-20s %-10d", "Unclassified", unclassified)) + + if classified > 0 then + local accuracy = (true_positives + true_negatives) / classified + local precision = true_positives > 0 and true_positives / (true_positives + false_positives) or 0 + local recall = true_positives > 0 and true_positives / (true_positives + false_negatives) or 0 + local f1_score = (precision + recall) > 0 and 2 * (precision * recall) / (precision + recall) or 0 + + print(string.format("%-20s %-10.4f", "Accuracy", accuracy)) + print(string.format("%-20s %-10.4f", "Precision", precision)) + print(string.format("%-20s %-10.4f", "Recall", recall)) + print(string.format("%-20s %-10.4f", "F1 Score", f1_score)) + end + + print(string.format("%-20s %-10.2f%%", "Classified", classified / total_cv_files * 100)) + print(string.format("%-20s %-10.2f", "Elapsed (sec)", elapsed)) end local function handler(args) opts = parser:parse(args) + local ham_directory = opts['ham'] local spam_directory = opts['spam'] + local classifier_type = opts['classifier'] + + if not ham_directory or not spam_directory then + print("Error: Both --ham and --spam directories are required") + os.exit(1) + end + + -- Set default symbols based on classifier type + if not opts.spam_symbol then + if classifier_type == 'llm_embeddings' then + opts.spam_symbol = 'NEURAL_SPAM' + else + opts.spam_symbol = 'BAYES_SPAM' + end + end + if not opts.ham_symbol then + if classifier_type == 'llm_embeddings' then + opts.ham_symbol = 'NEURAL_HAM' + else + opts.ham_symbol = 'BAYES_HAM' + end + end + -- Get all files local spam_files = get_files(spam_directory) local known_spam_files = lua_util.list_to_hash(spam_files) local ham_files = get_files(ham_directory) local known_ham_files = lua_util.list_to_hash(ham_files) - -- Split files into training and cross-validation sets + print(string.format("Classifier: %s", classifier_type)) + print(string.format("Found %d spam files, %d ham files", #spam_files, #ham_files)) + -- Split files into training and cross-validation sets local train_spam, cv_spam = split_table(spam_files, opts.cv_fraction) local train_ham, cv_ham = split_table(ham_files, opts.cv_fraction) - print(string.format("Spam: %d train files, %d cv files; ham: %d train files, %d cv files", + print(string.format("Split: %d/%d spam (train/test), %d/%d ham (train/test)", #train_spam, #cv_spam, #train_ham, #cv_ham)) + + -- Training phase if not opts.no_learning then - -- Train classifier + print("\n=== Training Phase ===") + local t, train_spam_time, train_ham_time - print(string.format("Start learn spam, %d messages, %d connections", #train_spam, opts.nconns)) - t = rspamd_util.get_time() - train_classifier(train_spam, "learn_spam") - train_spam_time = rspamd_util.get_time() - t - print(string.format("Start learn ham, %d messages, %d connections", #train_ham, opts.nconns)) - t = rspamd_util.get_time() - train_classifier(train_ham, "learn_ham") - train_ham_time = rspamd_util.get_time() - t - print(string.format("Learning done: %d spam messages in %.2f seconds, %d ham messages in %.2f seconds", - #train_spam, train_spam_time, #train_ham, train_ham_time)) + + if classifier_type == 'llm_embeddings' then + -- Neural/LLM training using ANN-Train header + -- Interleave spam and ham submissions for balanced training + print(string.format("Training %d spam + %d ham messages (interleaved)...", #train_spam, #train_ham)) + t = rspamd_util.get_time() + + -- Create interleaved list of {file, type} pairs + local interleaved = {} + local spam_idx, ham_idx = 1, 1 + while spam_idx <= #train_spam or ham_idx <= #train_ham do + if spam_idx <= #train_spam then + table.insert(interleaved, { file = train_spam[spam_idx], type = 'spam' }) + spam_idx = spam_idx + 1 + end + if ham_idx <= #train_ham then + table.insert(interleaved, { file = train_ham[ham_idx], type = 'ham' }) + ham_idx = ham_idx + 1 + end + end + + -- Submit in batches, grouped by type for efficiency + local batch_size = math.max(1, math.floor(#interleaved / 10)) + local spam_batch, ham_batch = {}, {} + local spam_trained, ham_trained = 0, 0 + + for i, item in ipairs(interleaved) do + if item.type == 'spam' then + table.insert(spam_batch, item.file) + else + table.insert(ham_batch, item.file) + end + + -- Submit batches periodically + if i % batch_size == 0 or i == #interleaved then + if #spam_batch > 0 then + spam_trained = spam_trained + train_neural(spam_batch, "spam") + spam_batch = {} + end + if #ham_batch > 0 then + ham_trained = ham_trained + train_neural(ham_batch, "ham") + ham_batch = {} + end + end + end + + train_spam_time = rspamd_util.get_time() - t + train_ham_time = 0 -- Combined time + print(string.format(" Submitted %d spam + %d ham samples in %.2f seconds", + spam_trained, ham_trained, train_spam_time)) + + -- Wait for neural network to train using ev_base sleep + print(string.format("\nWaiting %d seconds for neural network training...", opts.train_wait)) + rspamadm_ev_base:sleep(opts.train_wait) + print("Training wait complete.") + else + -- Bayes training using learn_spam/learn_ham + print(string.format("Start learn spam, %d messages, %d connections", #train_spam, opts.nconns)) + t = rspamd_util.get_time() + train_bayes(train_spam, "learn_spam") + train_spam_time = rspamd_util.get_time() - t + + print(string.format("Start learn ham, %d messages, %d connections", #train_ham, opts.nconns)) + t = rspamd_util.get_time() + train_bayes(train_ham, "learn_ham") + train_ham_time = rspamd_util.get_time() - t + + print(string.format("Learning done: %d spam in %.2f sec, %d ham in %.2f sec", + #train_spam, train_spam_time, #train_ham, train_ham_time)) + end + else + print("\nSkipping training phase (--no-learning)") + end + + if opts.train_only then + print("\nTraining only mode - skipping evaluation") + return end - -- Classify cross-validation files + -- Cross-validation phase + print("\n=== Evaluation Phase ===") + local cv_files = {} for _, file in ipairs(cv_spam) do table.insert(cv_files, file) @@ -219,7 +373,8 @@ local function handler(args) -- Shuffle cross-validation files cv_files = split_table(cv_files, 1) - print(string.format("Start cross validation, %d messages, %d connections", #cv_files, opts.nconns)) + print(string.format("Classifying %d test messages...", #cv_files)) + -- Get classification results local t = rspamd_util.get_time() local results = classify_files(cv_files, known_spam_files, known_ham_files) @@ -231,7 +386,6 @@ local function handler(args) known_ham_files, #cv_files, elapsed) - end return { diff --git a/rules/controller/neural.lua b/rules/controller/neural.lua index 628e4a62d7..4c8c931d0f 100644 --- a/rules/controller/neural.lua +++ b/rules/controller/neural.lua @@ -345,7 +345,9 @@ local function handle_learn_message(task, conn) lua_util.debugm(N, task, 'controller.neural: vector size=%s head=[%s]', vlen, vhead) local compressed = rspamd_util.zstd_compress(table.concat(vec, ';')) - local target_key = string.format('%s_%s_set', redis_base, learn_type) + -- Use pending key for manual training (picked up by training loop) + local pending_key = neural_common.pending_train_key(rule, set) + local target_key = string.format('%s_%s_set', pending_key, learn_type) local function learn_vec_cb(redis_err) if redis_err then @@ -385,11 +387,94 @@ local function handle_train(task, conn, req_params) conn:send_error(400, 'unknown rule') return end - -- Trigger check_anns to evaluate training conditions - rspamd_config:add_periodic(task:get_ev_base(), 0.0, function() - return 0.0 - end) - conn:send_ucl({ success = true, message = 'training scheduled check' }) + + -- Get the set for this rule + local set = neural_common.get_rule_settings(task, rule) + if not set then + -- Try to find any available set + for sid, s in pairs(rule.settings or {}) do + if type(s) == 'table' then + set = s + set.name = set.name or sid + break + end + end + end + + if not set then + conn:send_error(400, 'no settings found for rule') + return + end + + -- Check pending vectors count + local pending_key = neural_common.pending_train_key(rule, set) + local ev_base = task:get_ev_base() + + local function check_and_train(spam_count, ham_count) + if spam_count > 0 and ham_count > 0 then + -- We have vectors in pending, find or create a profile and train + local ann_key + if set.ann and set.ann.redis_key then + ann_key = set.ann.redis_key + else + -- Create a new profile + ann_key = neural_common.new_ann_key(rule, set, 0) + end + + rspamd_logger.infox(task, 'manual train trigger for %s:%s with %s spam, %s ham vectors', + rule.prefix, set.name, spam_count, ham_count) + + -- The training will be picked up by the next check_anns cycle + -- For immediate training, we'd need access to the worker object + conn:send_ucl({ + success = true, + message = 'training vectors available', + spam_vectors = spam_count, + ham_vectors = ham_count, + pending_key = pending_key + }) + else + conn:send_ucl({ + success = false, + message = 'not enough vectors for training', + spam_vectors = spam_count, + ham_vectors = ham_count + }) + end + end + + -- Count pending vectors + local spam_count = 0 + local function count_ham_cb(err, data) + local ham_count = 0 + if not err and (type(data) == 'number' or type(data) == 'string') then + ham_count = tonumber(data) or 0 + end + check_and_train(spam_count, ham_count) + end + + local function count_spam_cb(err, data) + if not err and (type(data) == 'number' or type(data) == 'string') then + spam_count = tonumber(data) or 0 + end + lua_redis.redis_make_request(task, + rule.redis, + nil, + false, + count_ham_cb, + 'SCARD', + { pending_key .. '_ham_set' } + ) + end + + lua_redis.redis_make_request(task, + rule.redis, + nil, + false, + count_spam_cb, + 'SCARD', + { pending_key .. '_spam_set' } + ) end return { diff --git a/src/lua/lua_util.c b/src/lua/lua_util.c index 12dfee02ac..a5dcf031b2 100644 --- a/src/lua/lua_util.c +++ b/src/lua/lua_util.c @@ -24,6 +24,7 @@ #include "libutil/str_util.h" #include "libserver/html/html.h" #include "libserver/hyperscan_tools.h" +#include "libserver/async_session.h" #include "lua_parsers.h" @@ -4474,9 +4475,17 @@ struct rspamd_ev_base_sleep_cbdata { lua_State *L; int cbref; struct thread_entry *thread; + struct rspamd_async_session *session; ev_timer ev; }; +/* Dummy finalizer for sleep session events - the timer handles cleanup */ +static void +lua_ev_base_sleep_session_fin(gpointer ud) +{ + /* Nothing to do - timer callback handles everything */ +} + static void lua_ev_base_sleep_cb(struct ev_loop *loop, struct ev_timer *t, int events) { @@ -4485,6 +4494,12 @@ lua_ev_base_sleep_cb(struct ev_loop *loop, struct ev_timer *t, int events) ev_timer_stop(loop, t); + /* Remove session event if we registered one */ + if (cbdata->session) { + rspamd_session_remove_event(cbdata->session, + lua_ev_base_sleep_session_fin, cbdata); + } + if (cbdata->cbref != -1) { /* Async mode: call the callback */ lua_State *L = cbdata->L; @@ -4516,6 +4531,25 @@ lua_ev_base_sleep_cb(struct ev_loop *loop, struct ev_timer *t, int events) * @param {number} time timeout in seconds * @param {function} callback optional callback for async mode */ +/* Helper to get rspamadm_session from Lua globals if available */ +static struct rspamd_async_session * +lua_get_rspamadm_session(lua_State *L) +{ + struct rspamd_async_session *session = NULL; + + lua_getglobal(L, "rspamadm_session"); + + if (lua_type(L, -1) == LUA_TUSERDATA) { + void *ud = rspamd_lua_check_udata_maybe(L, -1, rspamd_session_classname); + if (ud) { + session = *((struct rspamd_async_session **) ud); + } + } + + lua_pop(L, 1); + return session; +} + static int lua_ev_base_sleep(lua_State *L) { @@ -4534,12 +4568,23 @@ lua_ev_base_sleep(lua_State *L) cbdata->ev.data = cbdata; cbdata->cbref = -1; cbdata->thread = NULL; + cbdata->session = NULL; + + /* Try to get rspamadm_session for session event tracking */ + struct rspamd_async_session *session = lua_get_rspamadm_session(L); if (lua_isfunction(L, 3)) { /* Async mode with callback */ lua_pushvalue(L, 3); cbdata->cbref = luaL_ref(L, LUA_REGISTRYINDEX); + /* Register session event if available so wait_session_events waits for us */ + if (session) { + cbdata->session = session; + rspamd_session_add_event(session, + lua_ev_base_sleep_session_fin, cbdata, "lua sleep"); + } + ev_timer_init(&cbdata->ev, lua_ev_base_sleep_cb, timeout, 0.0); ev_timer_start(ev_base, &cbdata->ev); @@ -4556,6 +4601,13 @@ lua_ev_base_sleep(lua_State *L) if (cfg && cfg->lua_thread_pool) { cbdata->thread = lua_thread_pool_get_running_entry(cfg->lua_thread_pool); + /* Register session event if available so wait_session_events waits for us */ + if (session) { + cbdata->session = session; + rspamd_session_add_event(session, + lua_ev_base_sleep_session_fin, cbdata, "lua sleep"); + } + ev_timer_init(&cbdata->ev, lua_ev_base_sleep_cb, timeout, 0.0); ev_timer_start(ev_base, &cbdata->ev); diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua index 7054ab9232..bf703ec89b 100644 --- a/src/plugins/lua/neural.lua +++ b/src/plugins/lua/neural.lua @@ -116,10 +116,24 @@ local function ann_scores_filter(task) if ann then local function after_features(vec, meta) - if profile.providers_digest and meta and meta.digest and profile.providers_digest ~= meta.digest then - lua_util.debugm(N, task, 'providers digest mismatch for %s:%s, skip ANN apply', - rule.prefix, set.name) - vec = nil + -- For providers-based ANNs, require matching digest + -- For symbols-based ANNs (no providers), skip this check + local has_providers = rule.providers and #rule.providers > 0 + if has_providers then + local stored_digest = profile.providers_digest + local current_digest = meta and meta.digest + if not stored_digest then + -- Old ANN was trained without providers - needs retraining with current config + lua_util.debugm(N, task, + 'ANN %s:%s was trained without providers, skipping (retrain with current config)', + rule.prefix, set.name) + vec = nil + elseif stored_digest ~= current_digest then + rspamd_logger.warnx(task, + 'providers config changed for %s:%s (stored=%s, current=%s), ANN needs retraining', + rule.prefix, set.name, stored_digest, current_digest or 'none') + vec = nil + end end local score @@ -350,7 +364,13 @@ local function ann_push_task_result(rule, task, verdict, score, set) end local str = rspamd_util.zstd_compress(table.concat(vec, ';')) - local target_key = set.ann.redis_key .. '_' .. learn_type .. '_set' + -- For manual training, use stable pending key to avoid version mismatch + local target_key + if manual_train then + target_key = neural_common.pending_train_key(rule, set) .. '_' .. learn_type .. '_set' + else + target_key = set.ann.redis_key .. '_' .. learn_type .. '_set' + end local function learn_vec_cb(redis_err) if redis_err then @@ -405,13 +425,14 @@ local function ann_push_task_result(rule, task, verdict, score, set) end -- Check if we can learn - if set.can_store_vectors then + -- For manual training, bypass can_store_vectors check (it may not be set yet) + if set.can_store_vectors or manual_train then if not set.ann then -- Need to create or load a profile corresponding to the current configuration set.ann = new_ann_profile(task, rule, set, 0) lua_util.debugm(N, task, - 'requested new profile for %s, set.ann is missing', - set.name) + 'requested new profile for %s, set.ann is missing (manual_train=%s)', + set.name, manual_train) end lua_redis.exec_redis_script(neural_common.redis_script_id.vectors_len, @@ -444,12 +465,14 @@ end -- This function does the following: -- * Tries to lock ANN --- * Loads spam and ham vectors +-- * Loads spam and ham vectors (from versioned key AND pending key) -- * Spawn learning process local function do_train_ann(worker, ev_base, rule, set, ann_key) local spam_elts = {} local ham_elts = {} - lua_util.debugm(N, rspamd_config, 'do_train_ann: start for %s:%s key=%s', rule.prefix, set.name, ann_key) + local pending_key = neural_common.pending_train_key(rule, set) + lua_util.debugm(N, rspamd_config, 'do_train_ann: start for %s:%s key=%s pending=%s', + rule.prefix, set.name, ann_key, pending_key) local function redis_ham_cb(err, data) if err or type(data) ~= 'table' then @@ -475,7 +498,8 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key) set = set, ann_key = ann_key, ham_vec = ham_elts, - spam_vec = spam_elts + spam_vec = spam_elts, + pending_key = pending_key }) end end @@ -498,15 +522,15 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key) else -- Decompress and convert to numbers each training vector spam_elts = process_training_vectors(data) - -- Now get ham vectors... + -- Now get ham vectors from both versioned and pending keys lua_redis.redis_make_request_taskless(ev_base, rspamd_config, rule.redis, nil, false, -- is write redis_ham_cb, --callback - 'SMEMBERS', -- command - { ann_key .. '_ham_set' } + 'SUNION', -- command (union of sets) + { ann_key .. '_ham_set', pending_key .. '_ham_set' } ) end end @@ -517,18 +541,19 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key) ann_key, err) elseif type(data) == 'number' and data == 1 then -- ANN is locked, so we can extract SPAM and HAM vectors and spawn learning + -- Fetch from both versioned key and pending key using SUNION lua_redis.redis_make_request_taskless(ev_base, rspamd_config, rule.redis, nil, false, -- is write redis_spam_cb, --callback - 'SMEMBERS', -- command - { ann_key .. '_spam_set' } + 'SUNION', -- command (union of sets) + { ann_key .. '_spam_set', pending_key .. '_spam_set' } ) - rspamd_logger.infox(rspamd_config, 'lock ANN %s:%s (key name %s) for learning', - rule.prefix, set.name, ann_key) + rspamd_logger.infox(rspamd_config, 'lock ANN %s:%s (key name %s, pending %s) for learning', + rule.prefix, set.name, ann_key, pending_key) else local lock_tm = tonumber(data[1]) rspamd_logger.infox(rspamd_config, 'do not learn ANN %s:%s (key name %s), ' .. @@ -810,98 +835,132 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles) if sel_elt then -- We have our ANN and that's train vectors, check if we can learn local ann_key = sel_elt.redis_key + local pending_key = neural_common.pending_train_key(rule, set) - lua_util.debugm(N, rspamd_config, "check if ANN %s needs to be trained", - ann_key) - - -- Create continuation closure - local redis_len_cb_gen = function(cont_cb, what, is_final) - return function(err, data) - if err then - rspamd_logger.errx(rspamd_config, - 'cannot get ANN %s trains %s from redis: %s', what, ann_key, err) - elseif data and type(data) == 'number' or type(data) == 'string' then - local ntrains = tonumber(data) or 0 - lens[what] = ntrains - if is_final then - -- Ensure that we have the following: - -- one class has reached max_trains - -- other class(es) are at least as full as classes_bias - -- e.g. if classes_bias = 0.25 and we have 10 max_trains then - -- one class must have 10 or more trains whilst another should have - -- at least (10 * (1 - 0.25)) = 8 trains - - local max_len = math.max(lua_util.unpack(lua_util.values(lens))) - local min_len = math.min(lua_util.unpack(lua_util.values(lens))) - - if rule.train.learn_type == 'balanced' then - local len_bias_check_pred = function(_, l) - return l >= rule.train.max_trains * (1.0 - rule.train.classes_bias) - end - if max_len >= rule.train.max_trains and fun.all(len_bias_check_pred, lens) then - lua_util.debugm(N, rspamd_config, - 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors', - ann_key, lens, rule.train.max_trains, what) - cont_cb() - else - lua_util.debugm(N, rspamd_config, - 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)', - ann_key, what, lens, rule.train.max_trains) - end - else - -- Probabilistic mode, just ensure that at least one vector is okay - if min_len > 0 and max_len >= rule.train.max_trains then - lua_util.debugm(N, rspamd_config, - 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors', - ann_key, lens, rule.train.max_trains, what) - cont_cb() - else - lua_util.debugm(N, rspamd_config, - 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)', - ann_key, what, lens, rule.train.max_trains) - end - end - else - lua_util.debugm(N, rspamd_config, - 'checked %s vectors in ANN %s: %s vectors; %s required, need to check other class vectors', - what, ann_key, ntrains, rule.train.max_trains) - cont_cb() - end - end - end - end + lua_util.debugm(N, rspamd_config, "check if ANN %s (pending %s) needs to be trained", + ann_key, pending_key) local function initiate_train() rspamd_logger.infox(rspamd_config, - 'need to learn ANN %s after %s required learn vectors', - ann_key, lens) + 'need to learn ANN %s (pending %s) after %s required learn vectors', + ann_key, pending_key, lens) lua_util.debugm(N, rspamd_config, 'maybe_train_existing_ann: initiating train for key=%s spam=%s ham=%s', ann_key, lens.spam or -1, lens.ham or -1) do_train_ann(worker, ev_base, rule, set, ann_key) end - -- Spam vector is OK, check ham vector length + -- Final check after all vectors are counted + local function maybe_initiate_train() + local max_len = math.max(lua_util.unpack(lua_util.values(lens))) + local min_len = math.min(lua_util.unpack(lua_util.values(lens))) + + lua_util.debugm(N, rspamd_config, + 'final vector count for ANN %s: spam=%s ham=%s (min=%s max=%s required=%s)', + ann_key, lens.spam, lens.ham, min_len, max_len, rule.train.max_trains) + + if rule.train.learn_type == 'balanced' then + local len_bias_check_pred = function(_, l) + return l >= rule.train.max_trains * (1.0 - rule.train.classes_bias) + end + if max_len >= rule.train.max_trains and fun.all(len_bias_check_pred, lens) then + initiate_train() + else + lua_util.debugm(N, rspamd_config, + 'cannot learn ANN %s: balanced mode requires more vectors (has %s)', + ann_key, lens) + end + else + -- Probabilistic mode + if min_len > 0 and max_len >= rule.train.max_trains then + initiate_train() + else + lua_util.debugm(N, rspamd_config, + 'cannot learn ANN %s: need min_len > 0 and max_len >= %s (has %s)', + ann_key, rule.train.max_trains, lens) + end + end + end + + -- Callback that adds count from pending key and continues + local function add_pending_cb(cont_cb, what) + return function(err, data) + if not err and (type(data) == 'number' or type(data) == 'string') then + local pending_count = tonumber(data) or 0 + lens[what] = (lens[what] or 0) + pending_count + lua_util.debugm(N, rspamd_config, 'added %s pending %s vectors, total now %s', + pending_count, what, lens[what]) + end + cont_cb() + end + end + + -- Simple callback that just adds versioned count and continues + local function add_versioned_cb(cont_cb, what) + return function(err, data) + if not err and (type(data) == 'number' or type(data) == 'string') then + local count = tonumber(data) or 0 + lens[what] = (lens[what] or 0) + count + lua_util.debugm(N, rspamd_config, 'added %s versioned %s vectors, total now %s', + count, what, lens[what]) + end + cont_cb() + end + end + + -- Check pending ham, then make final decision + local function check_pending_ham() + lua_redis.redis_make_request_taskless(ev_base, + rspamd_config, + rule.redis, + nil, + false, + add_pending_cb(maybe_initiate_train, 'ham'), + 'SCARD', + { pending_key .. '_ham_set' } + ) + end + + -- Check versioned ham, then check pending ham local function check_ham_len() lua_redis.redis_make_request_taskless(ev_base, rspamd_config, rule.redis, nil, - false, -- is write - redis_len_cb_gen(initiate_train, 'ham', true), --callback - 'SCARD', -- command + false, + add_versioned_cb(check_pending_ham, 'ham'), + 'SCARD', { ann_key .. '_ham_set' } ) end - lua_redis.redis_make_request_taskless(ev_base, - rspamd_config, - rule.redis, - nil, - false, -- is write - redis_len_cb_gen(check_ham_len, 'spam', false), --callback - 'SCARD', -- command - { ann_key .. '_spam_set' } - ) + -- Check pending spam, then check ham + local function check_pending_spam() + lua_redis.redis_make_request_taskless(ev_base, + rspamd_config, + rule.redis, + nil, + false, + add_pending_cb(check_ham_len, 'spam'), + 'SCARD', + { pending_key .. '_spam_set' } + ) + end + + -- Check versioned spam, then pending spam + local function check_spam_len() + lua_redis.redis_make_request_taskless(ev_base, + rspamd_config, + rule.redis, + nil, + false, + add_versioned_cb(check_pending_spam, 'spam'), + 'SCARD', + { ann_key .. '_spam_set' } + ) + end + + -- Start the chain + check_spam_len() end end @@ -999,7 +1058,9 @@ local function ann_push_vector(task) lua_util.debugm(N, task, 'do not push data for skipped task') return end - if not settings.allow_local and lua_util.is_rspamc_or_controller(task) then + -- Allow manual training via ANN-Train header regardless of allow_local + local manual_train_header = get_ann_train_header(task) + if not settings.allow_local and not manual_train_header and lua_util.is_rspamc_or_controller(task) then lua_util.debugm(N, task, 'do not push data for manual scan') return end