settings.prefix, plugin_ver, rule.prefix, n, settings_name)
end
+-- Returns a stable key for pending training vectors (version-independent)
+-- Used for batch/manual training to avoid version mismatch issues
+-- @param rule table; only rule.prefix is read here
+-- @param set table; only set.name is read here
+-- @return string of the form '<settings.prefix>_<rule.prefix>_<set.name>_pending'
+local function pending_train_key(rule, set)
+ return string.format('%s_%s_%s_pending',
+ settings.prefix, rule.prefix, set.name)
+end
+
-- Compute a stable digest for providers configuration
local function providers_config_digest(providers_cfg, rule)
if not providers_cfg then return nil end
-- This function receives training vectors, checks them, spawn learning and saves ANN in Redis
local function spawn_train(params)
+ -- Prevent concurrent training
+ if params.set.learning_spawned then
+ lua_util.debugm(N, rspamd_config, 'spawn_train: training already in progress for %s:%s, skipping',
+ params.rule.prefix, params.set.name)
+ return
+ end
+
-- Check training data sanity
-- Now we need to join inputs and create the appropriate test vectors
local n
else
rspamd_logger.infox(rspamd_config, 'saved ANN %s:%s to redis: %s',
params.rule.prefix, params.set.name, params.set.ann.redis_key)
+
+ -- Clean up pending training keys if they were used
+ if params.pending_key then
+ local function cleanup_cb(cleanup_err)
+ if cleanup_err then
+ lua_util.debugm(N, rspamd_config, 'failed to cleanup pending keys: %s', cleanup_err)
+ else
+ lua_util.debugm(N, rspamd_config, 'cleaned up pending training keys for %s',
+ params.pending_key)
+ end
+ end
+ -- Delete both spam and ham pending sets
+ lua_redis.redis_make_request_taskless(params.ev_base,
+ rspamd_config,
+ params.rule.redis,
+ nil,
+ true, -- is write
+ cleanup_cb,
+ 'DEL',
+ { params.pending_key .. '_spam_set', params.pending_key .. '_ham_set' }
+ )
+ end
end
end
else
local parser = ucl.parser()
local ok, parse_err = parser:parse_text(data, 'msgpack')
- assert(ok, parse_err)
+ if not ok then
+ rspamd_logger.errx(rspamd_config, 'cannot parse training result for ANN %s:%s: %s (data size: %s)',
+ params.rule.prefix, params.set.name, parse_err, #data)
+ lua_redis.redis_make_request_taskless(params.ev_base,
+ rspamd_config,
+ params.rule.redis,
+ nil,
+ true,
+ gen_unlock_cb(params.rule, params.set, params.ann_key),
+ 'HDEL',
+ { params.ann_key, 'lock' }
+ )
+ return
+ end
local parsed = parser:get_object()
local ann_data = rspamd_util.zstd_compress(parsed.ann_data)
local pca_data = parsed.pca_data
-- Deserialise ANN from the child process
- ann_trained = rspamd_kann.load(parsed.ann_data)
+ local loaded_ann = rspamd_kann.load(parsed.ann_data)
local version = (params.set.ann.version or 0) + 1
params.set.ann.version = version
- params.set.ann.ann = ann_trained
+ params.set.ann.ann = loaded_ann
params.set.ann.symbols = params.set.symbols
params.set.ann.redis_key = new_ann_key(params.rule, params.set, version)
load_scripts = load_scripts,
module_config = module_config,
new_ann_key = new_ann_key,
+ pending_train_key = pending_train_key,
providers_config_digest = providers_config_digest,
register_provider = register_provider,
plugin_ver = plugin_ver,
local parser = argparse()
:name "rspamadm classifier_test"
- :description "Learn bayes classifier and evaluate its performance"
+ :description "Learn classifier and evaluate its performance"
:help_description_margin(32)
parser:option "-H --ham"
parser:option "-S --spam"
:description("Spam directory")
:argname("<dir>")
+parser:option "-C --classifier"
+ :description("Classifier type: bayes or llm_embeddings")
+ :argname("<type>")
+ :default('bayes')
parser:flag "-n --no-learning"
:description("Do not learn classifier")
+parser:flag "-T --train-only"
+ :description("Only train, do not evaluate (llm_embeddings only)")
parser:option "--nconns"
:description("Number of parallel connections")
:argname("<N>")
:description("Use specific rspamc path")
:argname("<path>")
:default('rspamc')
-parser:option "-c --cv-fraction"
+parser:option "-f --cv-fraction"
:description("Use specific fraction for cross-validation")
:argname("<fraction>")
:convert(tonumber)
- :default('0.7')
+ :default(0.7)
parser:option "--spam-symbol"
- :description("Use specific spam symbol (instead of BAYES_SPAM)")
+ :description("Use specific spam symbol (auto-detected from classifier type)")
:argname("<symbol>")
- :default('BAYES_SPAM')
parser:option "--ham-symbol"
- :description("Use specific ham symbol (instead of BAYES_HAM)")
+ :description("Use specific ham symbol (auto-detected from classifier type)")
:argname("<symbol>")
- :default('BAYES_HAM')
+parser:option "--train-wait"
+ :description("Seconds to wait after training for neural network (llm_embeddings only, should be > watch_interval)")
+ :argname("<sec>")
+ :convert(tonumber)
+ :default(90)
local opts
out:close()
end
--- Function to train the classifier with given files
-local function train_classifier(files, command)
+-- Function to train the Bayes classifier with given files
+-- @param files array of message file paths to learn from
+-- @param command rspamc subcommand, e.g. 'learn_spam' or 'learn_ham'
+-- rspamc output is read and discarded; the temporary list file is removed
+local function train_bayes(files, command)
local fname = os.tmpname()
list_to_file(files, fname)
local rspamc_command = string.format("%s --connect %s -j --compact -n %s -t %.3f %s --files-list=%s",
opts.rspamc, opts.connect, opts.nconns, opts.timeout, command, fname)
+ -- Drain the pipe so rspamc can finish writing before we close the handle
+ local handle = assert(io.popen(rspamc_command))
+ handle:read("*all")
+ handle:close()
+ os.remove(fname)
+end
+
+-- Function to train with ANN-Train header (for llm_embeddings/neural)
+-- Uses settings to enable only NEURAL_LEARN symbol, skipping full scan
+-- @param files array of message file paths to submit for training
+-- @param learn_type string placed in the ANN-Train header ('spam' or 'ham')
+-- @return number of rspamc output lines that parsed as valid UCL
+local function train_neural(files, learn_type)
+ local fname = os.tmpname()
+ list_to_file(files, fname)
+
+ -- Use ANN-Train header with settings to limit scan to NEURAL_LEARN only
+ local rspamc_command = string.format(
+ "%s --connect %s -j --compact -n %s -t %.3f " ..
+ "--settings '{\"symbols_enabled\":[\"NEURAL_LEARN\"]}' " ..
+ "--header 'ANN-Train=%s' --files-list=%s",
+ opts.rspamc, opts.connect, opts.nconns, opts.timeout,
+ learn_type, fname)
+
local result = assert(io.popen(rspamc_command))
- result = result:read("*all")
+ local output = result:read("*all")
+ result:close()
os.remove(fname)
+
+ -- Count successful submissions
+ -- NOTE(review): this counts every line that parses as UCL, which would
+ -- also include error replies that happen to be valid JSON - TODO confirm
+ -- against actual rspamc --compact output
+ local count = 0
+ for line in output:gmatch("[^\n]+") do
+ local ucl_parser = ucl.parser()
+ local is_good, _ = ucl_parser:parse_string(line)
+ if is_good then
+ count = count + 1
+ end
+ end
+
+ return count
end
-- Function to classify files and return results
local fname = os.tmpname()
list_to_file(files, fname)
- local settings_header = string.format('--header Settings=\"{symbols_enabled=[%s, %s]}\"',
+ local settings_header = string.format('--header Settings="{symbols_enabled=[%s, %s]}"',
opts.spam_symbol, opts.ham_symbol)
local rspamc_command = string.format("%s %s --connect %s --compact -n %s -t %.3f --files-list=%s",
opts.rspamc,
local is_good, err = ucl_parser:parse_string(line)
if not is_good then
rspamd_logger.errx("Parser error: %1", err)
- os.remove(fname)
- return nil
- end
- local obj = ucl_parser:get_object()
- local file = obj.filename
- local symbols = obj.symbols or {}
-
- if symbols[opts.spam_symbol] then
- table.insert(results, { result = "spam", file = file })
- if known_ham_files[file] then
- rspamd_logger.message("FP: %s is classified as spam but is known ham", file)
- end
- elseif symbols[opts.ham_symbol] then
- if known_spam_files[file] then
- rspamd_logger.message("FN: %s is classified as ham but is known spam", file)
+ else
+ local obj = ucl_parser:get_object()
+ local file = obj.filename
+ local symbols = obj.symbols or {}
+
+ if symbols[opts.spam_symbol] then
+ local score = symbols[opts.spam_symbol].score
+ table.insert(results, { result = "spam", file = file, score = score })
+ if known_ham_files[file] then
+ rspamd_logger.message("FP: %s is classified as spam but is known ham", file)
+ end
+ elseif symbols[opts.ham_symbol] then
+ local score = symbols[opts.ham_symbol].score
+ table.insert(results, { result = "ham", file = file, score = score })
+ if known_spam_files[file] then
+ rspamd_logger.message("FN: %s is classified as ham but is known spam", file)
+ end
+ else
+ -- No classification result
+ table.insert(results, { result = "unknown", file = file })
end
- table.insert(results, { result = "ham", file = file })
end
end
+ result:close()
os.remove(fname)
return results
-- Function to evaluate classifier performance
+-- Prints a TP/FP/TN/FN confusion summary plus accuracy/precision/recall/F1
+-- (metrics are only printed when at least one message was classified).
+-- @param results array of { result = <label>, file = <path>, ... } entries
+-- @param spam_label label string counted as a spam verdict
+-- @param ham_label label string counted as a ham verdict
+-- @param known_spam_files set (hash) of ground-truth spam file paths
+-- @param known_ham_files set (hash) of ground-truth ham file paths
+-- @param total_cv_files total number of cross-validation files scanned
+-- @param elapsed wall-clock seconds spent classifying
local function evaluate_results(results, spam_label, ham_label,
known_spam_files, known_ham_files, total_cv_files, elapsed)
- local true_positives, false_positives, true_negatives, false_negatives, total = 0, 0, 0, 0, 0
+ local true_positives, false_positives, true_negatives, false_negatives = 0, 0, 0, 0
+ local classified, unclassified = 0, 0
+
for _, res in ipairs(results) do
if res.result == spam_label then
if known_spam_files[res.file] then
elseif known_ham_files[res.file] then
false_positives = false_positives + 1
end
- total = total + 1
+ classified = classified + 1
elseif res.result == ham_label then
if known_spam_files[res.file] then
false_negatives = false_negatives + 1
elseif known_ham_files[res.file] then
true_negatives = true_negatives + 1
end
- total = total + 1
+ classified = classified + 1
+ else
+ unclassified = unclassified + 1
end
end
- local accuracy = (true_positives + true_negatives) / total
- local precision = true_positives / (true_positives + false_positives)
- local recall = true_positives / (true_positives + false_negatives)
- local f1_score = 2 * (precision * recall) / (precision + recall)
-
- print(string.format("%-20s %-10s", "Metric", "Value"))
- print(string.rep("-", 30))
+ print(string.format("\n%-20s %-10s", "Metric", "Value"))
+ print(string.rep("-", 35))
print(string.format("%-20s %-10d", "True Positives", true_positives))
print(string.format("%-20s %-10d", "False Positives", false_positives))
print(string.format("%-20s %-10d", "True Negatives", true_negatives))
print(string.format("%-20s %-10d", "False Negatives", false_negatives))
- print(string.format("%-20s %-10.2f", "Accuracy", accuracy))
- print(string.format("%-20s %-10.2f", "Precision", precision))
- print(string.format("%-20s %-10.2f", "Recall", recall))
- print(string.format("%-20s %-10.2f", "F1 Score", f1_score))
- print(string.format("%-20s %-10.2f", "Classified (%)", total / total_cv_files * 100))
- print(string.format("%-20s %-10.2f", "Elapsed time (seconds)", elapsed))
+ print(string.format("%-20s %-10d", "Unclassified", unclassified))
+
+ if classified > 0 then
+ local accuracy = (true_positives + true_negatives) / classified
+ -- Guarded ratios: avoid 0/0 (nan) when a class never appears
+ local precision = true_positives > 0 and true_positives / (true_positives + false_positives) or 0
+ local recall = true_positives > 0 and true_positives / (true_positives + false_negatives) or 0
+ local f1_score = (precision + recall) > 0 and 2 * (precision * recall) / (precision + recall) or 0
+
+ print(string.format("%-20s %-10.4f", "Accuracy", accuracy))
+ print(string.format("%-20s %-10.4f", "Precision", precision))
+ print(string.format("%-20s %-10.4f", "Recall", recall))
+ print(string.format("%-20s %-10.4f", "F1 Score", f1_score))
+ end
+
+ -- Guard against division by zero when no CV files were supplied
+ print(string.format("%-20s %-10.2f%%", "Classified", total_cv_files > 0 and classified / total_cv_files * 100 or 0))
+ print(string.format("%-20s %-10.2f", "Elapsed (sec)", elapsed))
end
local function handler(args)
opts = parser:parse(args)
+
local ham_directory = opts['ham']
local spam_directory = opts['spam']
+ local classifier_type = opts['classifier']
+
+ if not ham_directory or not spam_directory then
+ print("Error: Both --ham and --spam directories are required")
+ os.exit(1)
+ end
+
+ -- Set default symbols based on classifier type
+ if not opts.spam_symbol then
+ if classifier_type == 'llm_embeddings' then
+ opts.spam_symbol = 'NEURAL_SPAM'
+ else
+ opts.spam_symbol = 'BAYES_SPAM'
+ end
+ end
+ if not opts.ham_symbol then
+ if classifier_type == 'llm_embeddings' then
+ opts.ham_symbol = 'NEURAL_HAM'
+ else
+ opts.ham_symbol = 'BAYES_HAM'
+ end
+ end
+
-- Get all files
local spam_files = get_files(spam_directory)
local known_spam_files = lua_util.list_to_hash(spam_files)
local ham_files = get_files(ham_directory)
local known_ham_files = lua_util.list_to_hash(ham_files)
- -- Split files into training and cross-validation sets
+ print(string.format("Classifier: %s", classifier_type))
+ print(string.format("Found %d spam files, %d ham files", #spam_files, #ham_files))
+ -- Split files into training and cross-validation sets
local train_spam, cv_spam = split_table(spam_files, opts.cv_fraction)
local train_ham, cv_ham = split_table(ham_files, opts.cv_fraction)
- print(string.format("Spam: %d train files, %d cv files; ham: %d train files, %d cv files",
+ print(string.format("Split: %d/%d spam (train/test), %d/%d ham (train/test)",
#train_spam, #cv_spam, #train_ham, #cv_ham))
+
+ -- Training phase
if not opts.no_learning then
- -- Train classifier
+ print("\n=== Training Phase ===")
+
local t, train_spam_time, train_ham_time
- print(string.format("Start learn spam, %d messages, %d connections", #train_spam, opts.nconns))
- t = rspamd_util.get_time()
- train_classifier(train_spam, "learn_spam")
- train_spam_time = rspamd_util.get_time() - t
- print(string.format("Start learn ham, %d messages, %d connections", #train_ham, opts.nconns))
- t = rspamd_util.get_time()
- train_classifier(train_ham, "learn_ham")
- train_ham_time = rspamd_util.get_time() - t
- print(string.format("Learning done: %d spam messages in %.2f seconds, %d ham messages in %.2f seconds",
- #train_spam, train_spam_time, #train_ham, train_ham_time))
+
+ if classifier_type == 'llm_embeddings' then
+ -- Neural/LLM training using ANN-Train header
+ -- Interleave spam and ham submissions for balanced training
+ print(string.format("Training %d spam + %d ham messages (interleaved)...", #train_spam, #train_ham))
+ t = rspamd_util.get_time()
+
+ -- Create interleaved list of {file, type} pairs
+ local interleaved = {}
+ local spam_idx, ham_idx = 1, 1
+ while spam_idx <= #train_spam or ham_idx <= #train_ham do
+ if spam_idx <= #train_spam then
+ table.insert(interleaved, { file = train_spam[spam_idx], type = 'spam' })
+ spam_idx = spam_idx + 1
+ end
+ if ham_idx <= #train_ham then
+ table.insert(interleaved, { file = train_ham[ham_idx], type = 'ham' })
+ ham_idx = ham_idx + 1
+ end
+ end
+
+ -- Submit in batches, grouped by type for efficiency
+ local batch_size = math.max(1, math.floor(#interleaved / 10))
+ local spam_batch, ham_batch = {}, {}
+ local spam_trained, ham_trained = 0, 0
+
+ for i, item in ipairs(interleaved) do
+ if item.type == 'spam' then
+ table.insert(spam_batch, item.file)
+ else
+ table.insert(ham_batch, item.file)
+ end
+
+ -- Submit batches periodically
+ if i % batch_size == 0 or i == #interleaved then
+ if #spam_batch > 0 then
+ spam_trained = spam_trained + train_neural(spam_batch, "spam")
+ spam_batch = {}
+ end
+ if #ham_batch > 0 then
+ ham_trained = ham_trained + train_neural(ham_batch, "ham")
+ ham_batch = {}
+ end
+ end
+ end
+
+ train_spam_time = rspamd_util.get_time() - t
+ train_ham_time = 0 -- Combined time
+ print(string.format(" Submitted %d spam + %d ham samples in %.2f seconds",
+ spam_trained, ham_trained, train_spam_time))
+
+ -- Wait for neural network to train using ev_base sleep
+ print(string.format("\nWaiting %d seconds for neural network training...", opts.train_wait))
+ rspamadm_ev_base:sleep(opts.train_wait)
+ print("Training wait complete.")
+ else
+ -- Bayes training using learn_spam/learn_ham
+ print(string.format("Start learn spam, %d messages, %d connections", #train_spam, opts.nconns))
+ t = rspamd_util.get_time()
+ train_bayes(train_spam, "learn_spam")
+ train_spam_time = rspamd_util.get_time() - t
+
+ print(string.format("Start learn ham, %d messages, %d connections", #train_ham, opts.nconns))
+ t = rspamd_util.get_time()
+ train_bayes(train_ham, "learn_ham")
+ train_ham_time = rspamd_util.get_time() - t
+
+ print(string.format("Learning done: %d spam in %.2f sec, %d ham in %.2f sec",
+ #train_spam, train_spam_time, #train_ham, train_ham_time))
+ end
+ else
+ print("\nSkipping training phase (--no-learning)")
+ end
+
+ if opts.train_only then
+ print("\nTraining only mode - skipping evaluation")
+ return
end
- -- Classify cross-validation files
+ -- Cross-validation phase
+ print("\n=== Evaluation Phase ===")
+
local cv_files = {}
for _, file in ipairs(cv_spam) do
table.insert(cv_files, file)
-- Shuffle cross-validation files
cv_files = split_table(cv_files, 1)
- print(string.format("Start cross validation, %d messages, %d connections", #cv_files, opts.nconns))
+ print(string.format("Classifying %d test messages...", #cv_files))
+
-- Get classification results
local t = rspamd_util.get_time()
local results = classify_files(cv_files, known_spam_files, known_ham_files)
known_ham_files,
#cv_files,
elapsed)
-
end
return {
lua_util.debugm(N, task, 'controller.neural: vector size=%s head=[%s]', vlen, vhead)
local compressed = rspamd_util.zstd_compress(table.concat(vec, ';'))
- local target_key = string.format('%s_%s_set', redis_base, learn_type)
+ -- Use pending key for manual training (picked up by training loop)
+ local pending_key = neural_common.pending_train_key(rule, set)
+ local target_key = string.format('%s_%s_set', pending_key, learn_type)
local function learn_vec_cb(redis_err)
if redis_err then
conn:send_error(400, 'unknown rule')
return
end
- -- Trigger check_anns to evaluate training conditions
- rspamd_config:add_periodic(task:get_ev_base(), 0.0, function()
- return 0.0
- end)
- conn:send_ucl({ success = true, message = 'training scheduled check' })
+
+ -- Get the set for this rule
+ local set = neural_common.get_rule_settings(task, rule)
+ if not set then
+ -- Try to find any available set
+ for sid, s in pairs(rule.settings or {}) do
+ if type(s) == 'table' then
+ set = s
+ set.name = set.name or sid
+ break
+ end
+ end
+ end
+
+ if not set then
+ conn:send_error(400, 'no settings found for rule')
+ return
+ end
+
+ -- Check pending vectors count
+ local pending_key = neural_common.pending_train_key(rule, set)
+ local ev_base = task:get_ev_base()
+
+ -- Decide whether enough pending vectors exist and report back to the client.
+ -- Training itself is deferred: it is picked up by the periodic check_anns
+ -- cycle, so this only validates counts and answers the HTTP request.
+ local function check_and_train(spam_count, ham_count)
+ if spam_count > 0 and ham_count > 0 then
+ -- We have vectors in pending, find or create a profile and train
+ -- NOTE(review): ann_key is computed below but never used afterwards;
+ -- either wire it into the response/training trigger or drop it
+ local ann_key
+ if set.ann and set.ann.redis_key then
+ ann_key = set.ann.redis_key
+ else
+ -- Create a new profile
+ ann_key = neural_common.new_ann_key(rule, set, 0)
+ end
+
+ rspamd_logger.infox(task, 'manual train trigger for %s:%s with %s spam, %s ham vectors',
+ rule.prefix, set.name, spam_count, ham_count)
+
+ -- The training will be picked up by the next check_anns cycle
+ -- For immediate training, we'd need access to the worker object
+ conn:send_ucl({
+ success = true,
+ message = 'training vectors available',
+ spam_vectors = spam_count,
+ ham_vectors = ham_count,
+ pending_key = pending_key
+ })
+ else
+ conn:send_ucl({
+ success = false,
+ message = 'not enough vectors for training',
+ spam_vectors = spam_count,
+ ham_vectors = ham_count
+ })
+ end
+ end
+
+ -- Count pending vectors
+ local spam_count = 0
+ local function count_ham_cb(err, data)
+ local ham_count = 0
+ if not err and (type(data) == 'number' or type(data) == 'string') then
+ ham_count = tonumber(data) or 0
+ end
+ check_and_train(spam_count, ham_count)
+ end
+
+ local function count_spam_cb(err, data)
+ if not err and (type(data) == 'number' or type(data) == 'string') then
+ spam_count = tonumber(data) or 0
+ end
+ lua_redis.redis_make_request(task,
+ rule.redis,
+ nil,
+ false,
+ count_ham_cb,
+ 'SCARD',
+ { pending_key .. '_ham_set' }
+ )
+ end
+
+ lua_redis.redis_make_request(task,
+ rule.redis,
+ nil,
+ false,
+ count_spam_cb,
+ 'SCARD',
+ { pending_key .. '_spam_set' }
+ )
end
return {
#include "libutil/str_util.h"
#include "libserver/html/html.h"
#include "libserver/hyperscan_tools.h"
+#include "libserver/async_session.h"
#include "lua_parsers.h"
lua_State *L;
int cbref;
struct thread_entry *thread;
+ struct rspamd_async_session *session;
ev_timer ev;
};
+/* Dummy finalizer for sleep session events - the timer handles cleanup */
+static void
+lua_ev_base_sleep_session_fin(gpointer ud)
+{
+ (void) ud; /* unused: signature is dictated by the session event API */
+ /* Nothing to do - timer callback handles everything */
+}
+
static void
lua_ev_base_sleep_cb(struct ev_loop *loop, struct ev_timer *t, int events)
{
ev_timer_stop(loop, t);
+ /* Remove session event if we registered one */
+ if (cbdata->session) {
+ rspamd_session_remove_event(cbdata->session,
+ lua_ev_base_sleep_session_fin, cbdata);
+ }
+
if (cbdata->cbref != -1) {
/* Async mode: call the callback */
lua_State *L = cbdata->L;
* @param {number} time timeout in seconds
* @param {function} callback optional callback for async mode
*/
+/* Helper to get rspamadm_session from Lua globals if available */
+/* NOTE(review): this helper is inserted between lua_ev_base_sleep's doc
+ * comment block and its definition; consider moving it above that comment
+ * so the documentation stays attached to the function it describes. */
+static struct rspamd_async_session *
+lua_get_rspamadm_session(lua_State *L)
+{
+ struct rspamd_async_session *session = NULL;
+
+ /* Pushes _G["rspamadm_session"] (or nil) onto the Lua stack */
+ lua_getglobal(L, "rspamadm_session");
+
+ if (lua_type(L, -1) == LUA_TUSERDATA) {
+ void *ud = rspamd_lua_check_udata_maybe(L, -1, rspamd_session_classname);
+ if (ud) {
+ session = *((struct rspamd_async_session **) ud);
+ }
+ }
+
+ /* Balance the stack on every path (value was pushed unconditionally) */
+ lua_pop(L, 1);
+ return session;
+}
+
static int
lua_ev_base_sleep(lua_State *L)
{
cbdata->ev.data = cbdata;
cbdata->cbref = -1;
cbdata->thread = NULL;
+ cbdata->session = NULL;
+
+ /* Try to get rspamadm_session for session event tracking */
+ struct rspamd_async_session *session = lua_get_rspamadm_session(L);
if (lua_isfunction(L, 3)) {
/* Async mode with callback */
lua_pushvalue(L, 3);
cbdata->cbref = luaL_ref(L, LUA_REGISTRYINDEX);
+ /* Register session event if available so wait_session_events waits for us */
+ if (session) {
+ cbdata->session = session;
+ rspamd_session_add_event(session,
+ lua_ev_base_sleep_session_fin, cbdata, "lua sleep");
+ }
+
ev_timer_init(&cbdata->ev, lua_ev_base_sleep_cb, timeout, 0.0);
ev_timer_start(ev_base, &cbdata->ev);
if (cfg && cfg->lua_thread_pool) {
cbdata->thread = lua_thread_pool_get_running_entry(cfg->lua_thread_pool);
+ /* Register session event if available so wait_session_events waits for us */
+ if (session) {
+ cbdata->session = session;
+ rspamd_session_add_event(session,
+ lua_ev_base_sleep_session_fin, cbdata, "lua sleep");
+ }
+
ev_timer_init(&cbdata->ev, lua_ev_base_sleep_cb, timeout, 0.0);
ev_timer_start(ev_base, &cbdata->ev);
if ann then
local function after_features(vec, meta)
- if profile.providers_digest and meta and meta.digest and profile.providers_digest ~= meta.digest then
- lua_util.debugm(N, task, 'providers digest mismatch for %s:%s, skip ANN apply',
- rule.prefix, set.name)
- vec = nil
+ -- For providers-based ANNs, require matching digest
+ -- For symbols-based ANNs (no providers), skip this check
+ local has_providers = rule.providers and #rule.providers > 0
+ if has_providers then
+ local stored_digest = profile.providers_digest
+ local current_digest = meta and meta.digest
+ if not stored_digest then
+ -- Old ANN was trained without providers - needs retraining with current config
+ lua_util.debugm(N, task,
+ 'ANN %s:%s was trained without providers, skipping (retrain with current config)',
+ rule.prefix, set.name)
+ vec = nil
+ elseif stored_digest ~= current_digest then
+ rspamd_logger.warnx(task,
+ 'providers config changed for %s:%s (stored=%s, current=%s), ANN needs retraining',
+ rule.prefix, set.name, stored_digest, current_digest or 'none')
+ vec = nil
+ end
end
local score
end
local str = rspamd_util.zstd_compress(table.concat(vec, ';'))
- local target_key = set.ann.redis_key .. '_' .. learn_type .. '_set'
+ -- For manual training, use stable pending key to avoid version mismatch
+ local target_key
+ if manual_train then
+ target_key = neural_common.pending_train_key(rule, set) .. '_' .. learn_type .. '_set'
+ else
+ target_key = set.ann.redis_key .. '_' .. learn_type .. '_set'
+ end
local function learn_vec_cb(redis_err)
if redis_err then
end
-- Check if we can learn
- if set.can_store_vectors then
+ -- For manual training, bypass can_store_vectors check (it may not be set yet)
+ if set.can_store_vectors or manual_train then
if not set.ann then
-- Need to create or load a profile corresponding to the current configuration
set.ann = new_ann_profile(task, rule, set, 0)
lua_util.debugm(N, task,
- 'requested new profile for %s, set.ann is missing',
- set.name)
+ 'requested new profile for %s, set.ann is missing (manual_train=%s)',
+ set.name, manual_train)
end
lua_redis.exec_redis_script(neural_common.redis_script_id.vectors_len,
-- This function does the following:
-- * Tries to lock ANN
--- * Loads spam and ham vectors
+-- * Loads spam and ham vectors (from versioned key AND pending key)
-- * Spawn learning process
local function do_train_ann(worker, ev_base, rule, set, ann_key)
local spam_elts = {}
local ham_elts = {}
- lua_util.debugm(N, rspamd_config, 'do_train_ann: start for %s:%s key=%s', rule.prefix, set.name, ann_key)
+ local pending_key = neural_common.pending_train_key(rule, set)
+ lua_util.debugm(N, rspamd_config, 'do_train_ann: start for %s:%s key=%s pending=%s',
+ rule.prefix, set.name, ann_key, pending_key)
local function redis_ham_cb(err, data)
if err or type(data) ~= 'table' then
set = set,
ann_key = ann_key,
ham_vec = ham_elts,
- spam_vec = spam_elts
+ spam_vec = spam_elts,
+ pending_key = pending_key
})
end
end
else
-- Decompress and convert to numbers each training vector
spam_elts = process_training_vectors(data)
- -- Now get ham vectors...
+ -- Now get ham vectors from both versioned and pending keys
lua_redis.redis_make_request_taskless(ev_base,
rspamd_config,
rule.redis,
nil,
false, -- is write
redis_ham_cb, --callback
- 'SMEMBERS', -- command
- { ann_key .. '_ham_set' }
+ 'SUNION', -- command (union of sets)
+ { ann_key .. '_ham_set', pending_key .. '_ham_set' }
)
end
end
ann_key, err)
elseif type(data) == 'number' and data == 1 then
-- ANN is locked, so we can extract SPAM and HAM vectors and spawn learning
+ -- Fetch from both versioned key and pending key using SUNION
lua_redis.redis_make_request_taskless(ev_base,
rspamd_config,
rule.redis,
nil,
false, -- is write
redis_spam_cb, --callback
- 'SMEMBERS', -- command
- { ann_key .. '_spam_set' }
+ 'SUNION', -- command (union of sets)
+ { ann_key .. '_spam_set', pending_key .. '_spam_set' }
)
- rspamd_logger.infox(rspamd_config, 'lock ANN %s:%s (key name %s) for learning',
- rule.prefix, set.name, ann_key)
+ rspamd_logger.infox(rspamd_config, 'lock ANN %s:%s (key name %s, pending %s) for learning',
+ rule.prefix, set.name, ann_key, pending_key)
else
local lock_tm = tonumber(data[1])
rspamd_logger.infox(rspamd_config, 'do not learn ANN %s:%s (key name %s), ' ..
if sel_elt then
-- We have our ANN and that's train vectors, check if we can learn
local ann_key = sel_elt.redis_key
+ local pending_key = neural_common.pending_train_key(rule, set)
- lua_util.debugm(N, rspamd_config, "check if ANN %s needs to be trained",
- ann_key)
-
- -- Create continuation closure
- local redis_len_cb_gen = function(cont_cb, what, is_final)
- return function(err, data)
- if err then
- rspamd_logger.errx(rspamd_config,
- 'cannot get ANN %s trains %s from redis: %s', what, ann_key, err)
- elseif data and type(data) == 'number' or type(data) == 'string' then
- local ntrains = tonumber(data) or 0
- lens[what] = ntrains
- if is_final then
- -- Ensure that we have the following:
- -- one class has reached max_trains
- -- other class(es) are at least as full as classes_bias
- -- e.g. if classes_bias = 0.25 and we have 10 max_trains then
- -- one class must have 10 or more trains whilst another should have
- -- at least (10 * (1 - 0.25)) = 8 trains
-
- local max_len = math.max(lua_util.unpack(lua_util.values(lens)))
- local min_len = math.min(lua_util.unpack(lua_util.values(lens)))
-
- if rule.train.learn_type == 'balanced' then
- local len_bias_check_pred = function(_, l)
- return l >= rule.train.max_trains * (1.0 - rule.train.classes_bias)
- end
- if max_len >= rule.train.max_trains and fun.all(len_bias_check_pred, lens) then
- lua_util.debugm(N, rspamd_config,
- 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
- ann_key, lens, rule.train.max_trains, what)
- cont_cb()
- else
- lua_util.debugm(N, rspamd_config,
- 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
- ann_key, what, lens, rule.train.max_trains)
- end
- else
- -- Probabilistic mode, just ensure that at least one vector is okay
- if min_len > 0 and max_len >= rule.train.max_trains then
- lua_util.debugm(N, rspamd_config,
- 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
- ann_key, lens, rule.train.max_trains, what)
- cont_cb()
- else
- lua_util.debugm(N, rspamd_config,
- 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
- ann_key, what, lens, rule.train.max_trains)
- end
- end
- else
- lua_util.debugm(N, rspamd_config,
- 'checked %s vectors in ANN %s: %s vectors; %s required, need to check other class vectors',
- what, ann_key, ntrains, rule.train.max_trains)
- cont_cb()
- end
- end
- end
- end
+ lua_util.debugm(N, rspamd_config, "check if ANN %s (pending %s) needs to be trained",
+ ann_key, pending_key)
local function initiate_train()
rspamd_logger.infox(rspamd_config,
- 'need to learn ANN %s after %s required learn vectors',
- ann_key, lens)
+ 'need to learn ANN %s (pending %s) after %s required learn vectors',
+ ann_key, pending_key, lens)
lua_util.debugm(N, rspamd_config, 'maybe_train_existing_ann: initiating train for key=%s spam=%s ham=%s', ann_key,
lens.spam or -1, lens.ham or -1)
do_train_ann(worker, ev_base, rule, set, ann_key)
end
- -- Spam vector is OK, check ham vector length
+ -- Final check after all vectors are counted
+ -- Precondition: lens contains both 'spam' and 'ham' totals (set by the
+ -- preceding SCARD callback chain), otherwise math.max/min would error.
+ local function maybe_initiate_train()
+ local max_len = math.max(lua_util.unpack(lua_util.values(lens)))
+ local min_len = math.min(lua_util.unpack(lua_util.values(lens)))
+
+ lua_util.debugm(N, rspamd_config,
+ 'final vector count for ANN %s: spam=%s ham=%s (min=%s max=%s required=%s)',
+ ann_key, lens.spam, lens.ham, min_len, max_len, rule.train.max_trains)
+
+ if rule.train.learn_type == 'balanced' then
+ -- Balanced: one class must hit max_trains and every class must be
+ -- within classes_bias of it (>= max_trains * (1 - classes_bias))
+ local len_bias_check_pred = function(_, l)
+ return l >= rule.train.max_trains * (1.0 - rule.train.classes_bias)
+ end
+ if max_len >= rule.train.max_trains and fun.all(len_bias_check_pred, lens) then
+ initiate_train()
+ else
+ lua_util.debugm(N, rspamd_config,
+ 'cannot learn ANN %s: balanced mode requires more vectors (has %s)',
+ ann_key, lens)
+ end
+ else
+ -- Probabilistic mode
+ -- Weaker requirement: every class non-empty, one class at max_trains
+ if min_len > 0 and max_len >= rule.train.max_trains then
+ initiate_train()
+ else
+ lua_util.debugm(N, rspamd_config,
+ 'cannot learn ANN %s: need min_len > 0 and max_len >= %s (has %s)',
+ ann_key, rule.train.max_trains, lens)
+ end
+ end
+ end
+
+ -- Callback that adds count from pending key and continues
+ -- On Redis error the count is silently treated as 0 and the chain still
+ -- continues via cont_cb (may undercount rather than abort).
+ -- NOTE(review): behaviorally identical to add_versioned_cb below except
+ -- for the debug wording; consider one parameterized factory instead.
+ local function add_pending_cb(cont_cb, what)
+ return function(err, data)
+ if not err and (type(data) == 'number' or type(data) == 'string') then
+ local pending_count = tonumber(data) or 0
+ lens[what] = (lens[what] or 0) + pending_count
+ lua_util.debugm(N, rspamd_config, 'added %s pending %s vectors, total now %s',
+ pending_count, what, lens[what])
+ end
+ cont_cb()
+ end
+ end
+
+ -- Simple callback that just adds versioned count and continues
+ -- Same error semantics as add_pending_cb: errors count as 0, chain continues.
+ local function add_versioned_cb(cont_cb, what)
+ return function(err, data)
+ if not err and (type(data) == 'number' or type(data) == 'string') then
+ local count = tonumber(data) or 0
+ lens[what] = (lens[what] or 0) + count
+ lua_util.debugm(N, rspamd_config, 'added %s versioned %s vectors, total now %s',
+ count, what, lens[what])
+ end
+ cont_cb()
+ end
+ end
+
+ -- Check pending ham, then make final decision
+ local function check_pending_ham()
+ lua_redis.redis_make_request_taskless(ev_base,
+ rspamd_config,
+ rule.redis,
+ nil,
+ false,
+ add_pending_cb(maybe_initiate_train, 'ham'),
+ 'SCARD',
+ { pending_key .. '_ham_set' }
+ )
+ end
+
+ -- Check versioned ham, then check pending ham
local function check_ham_len()
lua_redis.redis_make_request_taskless(ev_base,
rspamd_config,
rule.redis,
nil,
- false, -- is write
- redis_len_cb_gen(initiate_train, 'ham', true), --callback
- 'SCARD', -- command
+ false,
+ add_versioned_cb(check_pending_ham, 'ham'),
+ 'SCARD',
{ ann_key .. '_ham_set' }
)
end
- lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- false, -- is write
- redis_len_cb_gen(check_ham_len, 'spam', false), --callback
- 'SCARD', -- command
- { ann_key .. '_spam_set' }
- )
+ -- Check pending spam, then check ham
+ local function check_pending_spam()
+ lua_redis.redis_make_request_taskless(ev_base,
+ rspamd_config,
+ rule.redis,
+ nil,
+ false,
+ add_pending_cb(check_ham_len, 'spam'),
+ 'SCARD',
+ { pending_key .. '_spam_set' }
+ )
+ end
+
+ -- Check versioned spam, then pending spam
+ local function check_spam_len()
+ lua_redis.redis_make_request_taskless(ev_base,
+ rspamd_config,
+ rule.redis,
+ nil,
+ false,
+ add_versioned_cb(check_pending_spam, 'spam'),
+ 'SCARD',
+ { ann_key .. '_spam_set' }
+ )
+ end
+
+ -- Start the chain
+ check_spam_len()
end
end
lua_util.debugm(N, task, 'do not push data for skipped task')
return
end
- if not settings.allow_local and lua_util.is_rspamc_or_controller(task) then
+ -- Allow manual training via ANN-Train header regardless of allow_local
+ local manual_train_header = get_ann_train_header(task)
+ if not settings.allow_local and not manual_train_header and lua_util.is_rspamc_or_controller(task) then
lua_util.debugm(N, task, 'do not push data for manual scan')
return
end