From: Vsevolod Stakhov Date: Sat, 16 May 2026 19:03:12 +0000 (+0100) Subject: [Fix] neural: preserve trained ANN across symcache-driven profile rotation X-Git-Tag: 4.1.0~55^2~1 X-Git-Url: http://git.ipfire.org/gitweb/index.cgi?a=commitdiff_plain;h=24933a9aaf1df58f46cb3bd4b46a69e28f684d83;p=thirdparty%2Frspamd.git [Fix] neural: preserve trained ANN across symcache-driven profile rotation When rspamd's symbol cache shifts (any added/removed symbol, even unrelated to the neural rule), the per-rule symbol digest changes and the plugin historically picked a brand-new profile — abandoning the previously-trained ANN at the old redis_key. In deployments where the input vector is built from providers (e.g. fasttext_embed conv1d) and `disable_symbols_input` is set, the symbol list is irrelevant to the vector schema, so the rotation needlessly reset inference until enough new training data accumulated. Make providers_digest the authoritative schema fingerprint when providers are configured: * New helper `is_profile_compatible` in lualib/plugins/neural.lua decides load eligibility based on providers_digest first; symbol-list drift is ignored entirely when `disable_symbols_input = true`, and tolerated without bound for hybrid (providers + symbols) rules where symbols form only a minor slice of the fused vector. Pure-symbols rules keep the legacy 30% Levenshtein tolerance and now also reject profiles that were trained with providers (vector schemas differ). * process_existing_ann/maybe_train_existing_ann use the new helper, and the reload decision in process_existing_ann picks the fresher version when the providers schema matches across a symbol-digest shift. * new_ann_profile triggers an async carryover after ZADD: ZREVRANGE the zset, find the most recent prior profile with a matching providers_digest, HMGET its ann/roc_thresholds/pca/providers_meta/ norm_stats, and HMSET them into the fresh redis_key. Gated on HEXISTS new_key ann == 0 so a freshly-trained model is never overwritten. --- diff --git a/lualib/plugins/neural.lua b/lualib/plugins/neural.lua index 3ba3799da3..358ad080a0 100644 --- a/lualib/plugins/neural.lua +++ b/lualib/plugins/neural.lua @@ -717,6 +717,59 @@ local function pending_train_key(rule, set) settings.prefix, rule.prefix, set.name) end +-- Check whether a candidate profile (loaded from the zset) is compatible with +-- the running rule/set configuration for the purposes of loading the trained +-- ANN. Compatibility is governed by the vector schema fingerprint: +-- +-- * has_providers + disable_symbols_input: symbols never enter the input +-- vector, so providers_digest alone is authoritative. Symbol-list drift +-- is ignored (dist = 0 when providers_digest matches). +-- * has_providers (hybrid mode): providers_digest must match (otherwise the +-- fused vector dimensions differ); symbol drift is tolerated and surfaced +-- as the returned dist for the caller's tie-breaking. +-- * pure symbols (no providers): legacy Levenshtein-tolerance — accept when +-- dist < 30% of |set.symbols|. +-- +-- Profiles trained with providers are rejected for pure-symbol rules (mixed +-- vector schemas) and vice versa. +-- +-- Returns (compatible_bool, dist_number). `dist` is math.huge on rejection. +local function is_profile_compatible(rule, set, profile_elt, current_providers_digest) + if not profile_elt then return false, math.huge end + local has_providers = rule.providers and #rule.providers > 0 + + if has_providers then + if not current_providers_digest or not profile_elt.providers_digest then + return false, math.huge + end + if profile_elt.providers_digest ~= current_providers_digest then + return false, math.huge + end + if rule.disable_symbols_input then + return true, 0 + end + local dist = 0 + if profile_elt.symbols and set.symbols then + dist = lua_util.distance_sorted(profile_elt.symbols, set.symbols) + end + return true, dist + end + + -- Pure symbols mode: reject profiles trained with providers (vector schemas + -- would be incompatible). + if profile_elt.providers_digest then + return false, math.huge + end + if not profile_elt.symbols or not set.symbols then + return false, math.huge + end + local dist = lua_util.distance_sorted(profile_elt.symbols, set.symbols) + if dist >= #set.symbols * 0.3 then + return false, dist + end + return true, dist +end + -- Compute a stable digest for providers configuration local function providers_config_digest(providers_cfg, rule) if not providers_cfg then return nil end @@ -1495,6 +1548,7 @@ return { gen_unlock_cb = gen_unlock_cb, get_provider = get_provider, get_rule_settings = get_rule_settings, + is_profile_compatible = is_profile_compatible, load_scripts = load_scripts, module_config = module_config, new_ann_key = new_ann_key, diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua index de9bdb9bc6..c6a00c4fa2 100644 --- a/src/plugins/lua/neural.lua +++ b/src/plugins/lua/neural.lua @@ -57,9 +57,14 @@ end local has_blas = rspamd_tensor.has_blas() local text_cookie = rspamd_text.cookie +-- Forward declarations +local maybe_carryover_ann +local load_ann_profile + -- Creates and stores ANN profile in Redis local function new_ann_profile(task, rule, set, version) local ann_key = neural_common.new_ann_key(rule, set, version, settings) + local providers_digest = neural_common.providers_config_digest(rule.providers, rule) local profile = { symbols = set.symbols, @@ -67,7 +72,7 @@ local function new_ann_profile(task, rule, set, version) version = version, digest = set.digest, distance = 0, -- Since we are using our own profile - providers_digest = neural_common.providers_config_digest(rule.providers, rule), + providers_digest = providers_digest, } local ucl = require "ucl" @@ -80,6 +85,14 @@ local function new_ann_profile(task, rule, set, version) else rspamd_logger.infox(task, 'created new ANN profile for %s:%s, data stored at prefix %s', rule.prefix, set.name, profile.redis_key) + -- If a prior profile with the same providers_digest holds trained + -- weights, carry them over into the fresh profile key. This prevents + -- a symcache-driven profile rotation from abandoning a still-valid + -- ANN whenever the input vector schema is decided by providers + -- (rather than the symbol list). + if providers_digest then + maybe_carryover_ann(task, rule, set, ann_key, providers_digest) + end end end @@ -925,22 +938,25 @@ end -- the existing ones. -- Use this function to load ANNs as `callback` parameter for `check_anns` function local function process_existing_ann(_, ev_base, rule, set, profiles) - local my_symbols = set.symbols + local has_providers = rule.providers and #rule.providers > 0 + local current_providers_digest = has_providers and + neural_common.providers_config_digest(rule.providers, rule) or nil local min_diff = math.huge local sel_elt - lua_util.debugm(N, rspamd_config, 'process_existing_ann: have %s profiles for %s:%s', - type(profiles) == 'table' and #profiles or -1, rule.prefix, set.name) + lua_util.debugm(N, rspamd_config, + 'process_existing_ann: have %s profiles for %s:%s (providers_digest=%s)', + type(profiles) == 'table' and #profiles or -1, rule.prefix, set.name, + current_providers_digest or 'none') for _, elt in fun.iter(profiles) do - if elt and elt.symbols then - local dist = lua_util.distance_sorted(elt.symbols, my_symbols) - -- Check distance - if dist < #my_symbols * .3 then - -- Prefer profiles with smaller distance, or higher version when distance is equal - if dist < min_diff or (dist == min_diff and sel_elt and elt.version > sel_elt.version) then - min_diff = dist - sel_elt = elt - end + local compatible, dist = neural_common.is_profile_compatible( + rule, set, elt, current_providers_digest) + if compatible then + -- Prefer smaller distance; tie-break on higher version + if dist < min_diff + or (dist == min_diff and sel_elt and (elt.version or 0) > (sel_elt.version or 0)) then + min_diff = dist + sel_elt = elt end end end @@ -961,11 +977,18 @@ local function process_existing_ann(_, ev_base, rule, set, profiles) } -- We can load element from ANN if set.ann then - -- We have an existing ANN, probably the same... + -- Providers schema acts as the dominant identity when configured: even + -- if the symbol-digest portion drifted (symcache shift), a matching + -- providers_digest means the vector shape (and therefore the trained + -- weights) are still valid. Reload purely on version freshness in + -- that case. + local providers_compatible = has_providers and current_providers_digest + and set.ann.providers_digest == current_providers_digest + and sel_elt.providers_digest == current_providers_digest + if set.ann.digest == sel_elt.digest then -- Same ANN, check version - if set.ann.version < sel_elt.version then - -- Load new ann + if (set.ann.version or 0) < (sel_elt.version or 0) then rspamd_logger.infox(rspamd_config, 'ann %s is changed, ' .. 'our version = %s, remote version = %s', rule.prefix .. ':' .. set.name, @@ -979,10 +1002,22 @@ local function process_existing_ann(_, ev_base, rule, set, profiles) set.ann.version, sel_elt.version) end + elseif providers_compatible then + if (sel_elt.version or 0) > (set.ann.version or 0) then + rspamd_logger.infox(rspamd_config, + 'providers schema matches for %s; reload newer version %s (ours = %s)', + rule.prefix .. ':' .. set.name, + sel_elt.version, set.ann.version) + load_new_ann(rule, ev_base, set, sel_elt, min_diff) + else + lua_util.debugm(N, rspamd_config, + 'providers schema matches for %s; our version %s >= remote %s, no reload', + rule.prefix .. ':' .. set.name, + set.ann.version, sel_elt.version) + end else -- We have some different ANN, so we need to compare distance - if set.ann.distance > min_diff then - -- Load more specific ANN + if (set.ann.distance or math.huge) > min_diff then rspamd_logger.infox(rspamd_config, 'more specific ann is available for %s, ' .. 'our distance = %s, remote distance = %s', rule.prefix .. ':' .. set.name, @@ -1015,7 +1050,9 @@ end -- ANN. By our we mean that it has exactly the same symbols in profile. -- Use this function to train ANN as `callback` parameter for `check_anns` function local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles) - local my_symbols = set.symbols + local has_providers = rule.providers and #rule.providers > 0 + local current_providers_digest = has_providers and + neural_common.providers_config_digest(rule.providers, rule) or nil local sel_elt local lens = { spam = 0, @@ -1024,14 +1061,16 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles) lua_util.debugm(N, rspamd_config, 'maybe_train_existing_ann: %s profiles for %s:%s', type(profiles) == 'table' and #profiles or -1, rule.prefix, set.name) + -- Strict match: training data accumulated against an existing profile + -- must come from a compatible vector schema. is_profile_compatible + -- returns dist=0 when symbols are irrelevant (disable_symbols_input) or + -- when symbol-lists actually match. for _, elt in fun.iter(profiles) do - if elt and elt.symbols then - local dist = lua_util.distance_sorted(elt.symbols, my_symbols) - -- Check distance - if dist == 0 then - sel_elt = elt - break - end + local compatible, dist = neural_common.is_profile_compatible( + rule, set, elt, current_providers_digest) + if compatible and dist == 0 then + sel_elt = elt + break end end @@ -1175,7 +1214,7 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles) end -- Used to deserialise ANN element from a list -local function load_ann_profile(element) +load_ann_profile = function(element) local ucl = require "ucl" local parser = ucl.parser() @@ -1196,6 +1235,127 @@ local function load_ann_profile(element) end end +-- Async carryover: look up the most recent zset entry with the same +-- providers_digest and a trained ANN blob, then copy its +-- ann/roc_thresholds/pca/providers_meta/norm_stats fields into the freshly +-- created profile's redis_key. Only runs when the new key has no ANN yet, +-- so this never overwrites a freshly-trained model. +maybe_carryover_ann = function(task, rule, set, new_key, target_providers_digest) + local function zrange_cb(err, data) + if err or type(data) ~= 'table' then + lua_util.debugm(N, task, 'carryover: cannot read zset %s: %s', + set.prefix, err) + return + end + + local source_key + for _, raw in ipairs(data) do + local profile = load_ann_profile(raw) + if profile + and profile.providers_digest == target_providers_digest + and profile.redis_key ~= new_key then + source_key = profile.redis_key + break + end + end + + if not source_key then + lua_util.debugm(N, task, + 'carryover: no prior profile with matching providers_digest for %s:%s', + rule.prefix, set.name) + return + end + + local function hmset_cb(hmset_err) + if hmset_err then + rspamd_logger.errx(task, + 'carryover: cannot copy ANN from %s to %s: %s', + source_key, new_key, hmset_err) + else + rspamd_logger.infox(task, + 'carryover: copied ANN weights from %s into fresh profile %s ' .. + '(providers_digest unchanged)', + source_key, new_key) + end + end + + local function hmget_cb(hmget_err, hmget_data) + if hmget_err or type(hmget_data) ~= 'table' then + lua_util.debugm(N, task, + 'carryover: HMGET error for %s: %s', source_key, hmget_err) + return + end + if not (type(hmget_data[1]) == 'userdata' and hmget_data[1].cookie == text_cookie) then + lua_util.debugm(N, task, + 'carryover: source key %s has no ANN blob', source_key) + return + end + + local fields = { 'ann', 'roc_thresholds', 'pca', 'providers_meta', 'norm_stats' } + local args = { new_key } + for i, fname in ipairs(fields) do + local v = hmget_data[i] + if type(v) == 'userdata' and v.cookie == text_cookie then + args[#args + 1] = fname + args[#args + 1] = v + end + end + + if #args <= 1 then + lua_util.debugm(N, task, + 'carryover: nothing to copy from %s', source_key) + return + end + + lua_redis.redis_make_request(task, + rule.redis, + nil, + true, + hmset_cb, + 'HMSET', + args) + end + + local function exists_cb(hex_err, hex_data) + if hex_err then + lua_util.debugm(N, task, + 'carryover: HEXISTS error for %s: %s', new_key, hex_err) + return + end + if tonumber(hex_data) == 1 then + lua_util.debugm(N, task, + 'carryover: %s already has an ANN, skipping copy', new_key) + return + end + + lua_redis.redis_make_request(task, + rule.redis, + nil, + false, + hmget_cb, + 'HMGET', + { source_key, 'ann', 'roc_thresholds', 'pca', 'providers_meta', 'norm_stats' }, + { opaque_data = true }) + end + + lua_redis.redis_make_request(task, + rule.redis, + nil, + false, + exists_cb, + 'HEXISTS', + { new_key, 'ann' }) + end + + lua_redis.redis_make_request(task, + rule.redis, + nil, + false, + zrange_cb, + 'ZREVRANGE', + { set.prefix, '0', tostring(settings.max_profiles) }) +end + -- Function to check or load ANNs from Redis local function check_anns(worker, cfg, ev_base, rule, process_callback, what) for _, set in pairs(rule.settings) do