settings.prefix, rule.prefix, set.name)
end
+-- Decide whether a profile read from the zset may be used with the running
+-- rule/set configuration when loading a trained ANN.
+--
+-- Three modes are distinguished by the rule configuration:
+--
+--   * providers configured and rule.disable_symbols_input set: only the
+--     providers digest is compared; a match yields dist = 0.
+--   * providers configured, symbols still enabled: the providers digest must
+--     match; the symbol distance is computed (0 when either symbol list is
+--     missing) and returned for the caller's tie-breaking, but it never
+--     causes a rejection.
+--   * no providers: the profile must not carry a providers digest, both
+--     symbol lists must exist, and the distance must stay below 30% of
+--     #set.symbols.
+--
+-- Parameters:
+--   rule                     rule table (reads .providers, .disable_symbols_input)
+--   set                      settings element (reads .symbols)
+--   profile_elt              candidate profile (reads .providers_digest, .symbols)
+--   current_providers_digest digest of the running providers config, or nil
+--
+-- Returns (compatible, dist). `dist` is math.huge for every rejection except
+-- the pure-symbol tolerance failure, which returns the measured distance.
+local function is_profile_compatible(rule, set, profile_elt, current_providers_digest)
+  if not profile_elt then
+    return false, math.huge
+  end
+
+  local stored_digest = profile_elt.providers_digest
+  local uses_providers = rule.providers and #rule.providers > 0
+
+  if not uses_providers then
+    -- Pure symbols mode: a profile trained with providers is never usable
+    if stored_digest then
+      return false, math.huge
+    end
+    if not profile_elt.symbols or not set.symbols then
+      return false, math.huge
+    end
+    local drift = lua_util.distance_sorted(profile_elt.symbols, set.symbols)
+    local tolerance = #set.symbols * 0.3
+    if drift >= tolerance then
+      return false, drift
+    end
+    return true, drift
+  end
+
+  -- Providers mode: both digests must be present and equal
+  if not current_providers_digest or not stored_digest then
+    return false, math.huge
+  end
+  if stored_digest ~= current_providers_digest then
+    return false, math.huge
+  end
+
+  if rule.disable_symbols_input then
+    return true, 0
+  end
+
+  -- Hybrid mode: symbol drift is reported but does not reject the profile
+  local drift = 0
+  if profile_elt.symbols and set.symbols then
+    drift = lua_util.distance_sorted(profile_elt.symbols, set.symbols)
+  end
+  return true, drift
+end
+
-- Compute a stable digest for providers configuration
local function providers_config_digest(providers_cfg, rule)
if not providers_cfg then return nil end
gen_unlock_cb = gen_unlock_cb,
get_provider = get_provider,
get_rule_settings = get_rule_settings,
+ is_profile_compatible = is_profile_compatible,
load_scripts = load_scripts,
module_config = module_config,
new_ann_key = new_ann_key,
local has_blas = rspamd_tensor.has_blas()
local text_cookie = rspamd_text.cookie
+-- Forward declarations
+local maybe_carryover_ann
+local load_ann_profile
+
-- Creates and stores ANN profile in Redis
local function new_ann_profile(task, rule, set, version)
local ann_key = neural_common.new_ann_key(rule, set, version, settings)
+ local providers_digest = neural_common.providers_config_digest(rule.providers, rule)
local profile = {
symbols = set.symbols,
version = version,
digest = set.digest,
distance = 0, -- Since we are using our own profile
- providers_digest = neural_common.providers_config_digest(rule.providers, rule),
+ providers_digest = providers_digest,
}
local ucl = require "ucl"
else
rspamd_logger.infox(task, 'created new ANN profile for %s:%s, data stored at prefix %s',
rule.prefix, set.name, profile.redis_key)
+ -- If a prior profile with the same providers_digest holds trained
+ -- weights, carry them over into the fresh profile key. This prevents
+ -- a symcache-driven profile rotation from abandoning a still-valid
+ -- ANN whenever the input vector schema is decided by providers
+ -- (rather than the symbol list).
+ if providers_digest then
+ maybe_carryover_ann(task, rule, set, ann_key, providers_digest)
+ end
end
end
-- the existing ones.
-- Use this function to load ANNs as `callback` parameter for `check_anns` function
local function process_existing_ann(_, ev_base, rule, set, profiles)
- local my_symbols = set.symbols
+ local has_providers = rule.providers and #rule.providers > 0
+ local current_providers_digest = has_providers and
+ neural_common.providers_config_digest(rule.providers, rule) or nil
local min_diff = math.huge
local sel_elt
- lua_util.debugm(N, rspamd_config, 'process_existing_ann: have %s profiles for %s:%s',
- type(profiles) == 'table' and #profiles or -1, rule.prefix, set.name)
+ lua_util.debugm(N, rspamd_config,
+ 'process_existing_ann: have %s profiles for %s:%s (providers_digest=%s)',
+ type(profiles) == 'table' and #profiles or -1, rule.prefix, set.name,
+ current_providers_digest or 'none')
for _, elt in fun.iter(profiles) do
- if elt and elt.symbols then
- local dist = lua_util.distance_sorted(elt.symbols, my_symbols)
- -- Check distance
- if dist < #my_symbols * .3 then
- -- Prefer profiles with smaller distance, or higher version when distance is equal
- if dist < min_diff or (dist == min_diff and sel_elt and elt.version > sel_elt.version) then
- min_diff = dist
- sel_elt = elt
- end
+ local compatible, dist = neural_common.is_profile_compatible(
+ rule, set, elt, current_providers_digest)
+ if compatible then
+ -- Prefer smaller distance; tie-break on higher version
+ if dist < min_diff
+ or (dist == min_diff and sel_elt and (elt.version or 0) > (sel_elt.version or 0)) then
+ min_diff = dist
+ sel_elt = elt
end
end
end
}
-- We can load element from ANN
if set.ann then
- -- We have an existing ANN, probably the same...
+ -- Providers schema acts as the dominant identity when configured: even
+ -- if the symbol-digest portion drifted (symcache shift), a matching
+ -- providers_digest means the vector shape (and therefore the trained
+ -- weights) are still valid. Reload purely on version freshness in
+ -- that case.
+ local providers_compatible = has_providers and current_providers_digest
+ and set.ann.providers_digest == current_providers_digest
+ and sel_elt.providers_digest == current_providers_digest
+
if set.ann.digest == sel_elt.digest then
-- Same ANN, check version
- if set.ann.version < sel_elt.version then
- -- Load new ann
+ if (set.ann.version or 0) < (sel_elt.version or 0) then
rspamd_logger.infox(rspamd_config, 'ann %s is changed, ' ..
'our version = %s, remote version = %s',
rule.prefix .. ':' .. set.name,
set.ann.version,
sel_elt.version)
end
+ elseif providers_compatible then
+ if (sel_elt.version or 0) > (set.ann.version or 0) then
+ rspamd_logger.infox(rspamd_config,
+ 'providers schema matches for %s; reload newer version %s (ours = %s)',
+ rule.prefix .. ':' .. set.name,
+ sel_elt.version, set.ann.version)
+ load_new_ann(rule, ev_base, set, sel_elt, min_diff)
+ else
+ lua_util.debugm(N, rspamd_config,
+ 'providers schema matches for %s; our version %s >= remote %s, no reload',
+ rule.prefix .. ':' .. set.name,
+ set.ann.version, sel_elt.version)
+ end
else
-- We have some different ANN, so we need to compare distance
- if set.ann.distance > min_diff then
- -- Load more specific ANN
+ if (set.ann.distance or math.huge) > min_diff then
rspamd_logger.infox(rspamd_config, 'more specific ann is available for %s, ' ..
'our distance = %s, remote distance = %s',
rule.prefix .. ':' .. set.name,
-- ANN. By our we mean that it has exactly the same symbols in profile.
-- Use this function to train ANN as `callback` parameter for `check_anns` function
local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
- local my_symbols = set.symbols
+ local has_providers = rule.providers and #rule.providers > 0
+ local current_providers_digest = has_providers and
+ neural_common.providers_config_digest(rule.providers, rule) or nil
local sel_elt
local lens = {
spam = 0,
lua_util.debugm(N, rspamd_config, 'maybe_train_existing_ann: %s profiles for %s:%s',
type(profiles) == 'table' and #profiles or -1, rule.prefix, set.name)
+ -- Strict match: training data accumulated against an existing profile
+ -- must come from a compatible vector schema. is_profile_compatible
+ -- returns dist=0 when symbols are irrelevant (disable_symbols_input) or
+ -- when symbol-lists actually match.
for _, elt in fun.iter(profiles) do
- if elt and elt.symbols then
- local dist = lua_util.distance_sorted(elt.symbols, my_symbols)
- -- Check distance
- if dist == 0 then
- sel_elt = elt
- break
- end
+ local compatible, dist = neural_common.is_profile_compatible(
+ rule, set, elt, current_providers_digest)
+ if compatible and dist == 0 then
+ sel_elt = elt
+ break
end
end
end
-- Used to deserialise ANN element from a list
-local function load_ann_profile(element)
+load_ann_profile = function(element)
local ucl = require "ucl"
local parser = ucl.parser()
end
end
+-- Async carryover: copy trained ANN data from a prior profile into a freshly
+-- created profile key.
+--
+-- Steps (each one is a separate Redis request issued from the previous
+-- request's callback):
+--   1. ZREVRANGE on set.prefix to list stored profiles.
+--   2. Collect the redis_key of every profile whose providers_digest equals
+--      target_providers_digest and whose key differs from new_key.
+--   3. HEXISTS new_key 'ann': stop if the new key already holds an ANN.
+--   4. Walk the candidates in reply order; the first one whose 'ann' field is
+--      present is used as the source. Candidates without an ANN blob are
+--      skipped instead of aborting the whole carryover.
+--   5. HMSET the fields that were present
+--      (ann/roc_thresholds/pca/providers_meta/norm_stats) into new_key.
+--
+-- Parameters:
+--   task                    task used for logging and for the Redis requests
+--   rule                    rule table (reads .redis, .prefix)
+--   set                     settings element (reads .prefix, .name)
+--   new_key                 redis key of the freshly created profile
+--   target_providers_digest providers digest that a source profile must match
+--
+-- Returns nothing; failures are logged and the carryover is abandoned.
+--
+-- NOTE(review): the HEXISTS check and the final HMSET are separate round
+-- trips, so an ANN saved to new_key in between would still be overwritten;
+-- confirm whether that window matters in practice.
+-- NOTE(review): candidates are matched on providers_digest only, symbol lists
+-- are not compared here; confirm the copied weights stay valid when symbols
+-- also feed the input vector.
+maybe_carryover_ann = function(task, rule, set, new_key, target_providers_digest)
+  -- Hash fields copied from the source profile. 'ann' must stay first: its
+  -- presence in the HMGET reply decides whether a candidate is usable.
+  local copy_fields = { 'ann', 'roc_thresholds', 'pca', 'providers_meta', 'norm_stats' }
+
+  local function is_text(v)
+    return type(v) == 'userdata' and v.cookie == text_cookie
+  end
+
+  local function zrange_cb(err, data)
+    if err or type(data) ~= 'table' then
+      lua_util.debugm(N, task, 'carryover: cannot read zset %s: %s',
+        set.prefix, err)
+      return
+    end
+
+    -- Keys of all prior profiles sharing the providers digest, kept in the
+    -- order ZREVRANGE returned them (highest score first).
+    local candidates = {}
+    for _, raw in ipairs(data) do
+      local profile = load_ann_profile(raw)
+      if profile
+          and profile.redis_key
+          and profile.providers_digest == target_providers_digest
+          and profile.redis_key ~= new_key then
+        candidates[#candidates + 1] = profile.redis_key
+      end
+    end
+
+    if #candidates == 0 then
+      lua_util.debugm(N, task,
+        'carryover: no prior profile with matching providers_digest for %s:%s',
+        rule.prefix, set.name)
+      return
+    end
+
+    -- Write every field present in the HMGET reply into the new key
+    local function copy_from(source_key, hmget_data)
+      local args = { new_key }
+      for i, fname in ipairs(copy_fields) do
+        local v = hmget_data[i]
+        if is_text(v) then
+          args[#args + 1] = fname
+          args[#args + 1] = v
+        end
+      end
+
+      local function hmset_cb(hmset_err)
+        if hmset_err then
+          rspamd_logger.errx(task,
+            'carryover: cannot copy ANN from %s to %s: %s',
+            source_key, new_key, hmset_err)
+        else
+          rspamd_logger.infox(task,
+            'carryover: copied ANN weights from %s into fresh profile %s ' ..
+            '(providers_digest unchanged)',
+            source_key, new_key)
+        end
+      end
+
+      lua_redis.redis_make_request(task,
+        rule.redis,
+        nil,
+        true,
+        hmset_cb,
+        'HMSET',
+        args)
+    end
+
+    -- Probe candidates one at a time until one holds an ANN blob
+    local try_candidate
+    try_candidate = function(idx)
+      local source_key = candidates[idx]
+      if not source_key then
+        lua_util.debugm(N, task,
+          'carryover: no prior profile with a trained ANN for %s:%s',
+          rule.prefix, set.name)
+        return
+      end
+
+      local function hmget_cb(hmget_err, hmget_data)
+        if hmget_err or type(hmget_data) ~= 'table' then
+          -- A Redis-level failure is not retried on the next candidate
+          lua_util.debugm(N, task,
+            'carryover: HMGET error for %s: %s', source_key, hmget_err)
+          return
+        end
+        if not is_text(hmget_data[1]) then
+          lua_util.debugm(N, task,
+            'carryover: source key %s has no ANN blob', source_key)
+          try_candidate(idx + 1)
+          return
+        end
+
+        copy_from(source_key, hmget_data)
+      end
+
+      local hmget_args = { source_key }
+      for _, fname in ipairs(copy_fields) do
+        hmget_args[#hmget_args + 1] = fname
+      end
+
+      lua_redis.redis_make_request(task,
+        rule.redis,
+        nil,
+        false,
+        hmget_cb,
+        'HMGET',
+        hmget_args,
+        { opaque_data = true })
+    end
+
+    local function exists_cb(hex_err, hex_data)
+      if hex_err then
+        lua_util.debugm(N, task,
+          'carryover: HEXISTS error for %s: %s', new_key, hex_err)
+        return
+      end
+      if tonumber(hex_data) == 1 then
+        lua_util.debugm(N, task,
+          'carryover: %s already has an ANN, skipping copy', new_key)
+        return
+      end
+
+      try_candidate(1)
+    end
+
+    lua_redis.redis_make_request(task,
+      rule.redis,
+      nil,
+      false,
+      exists_cb,
+      'HEXISTS',
+      { new_key, 'ann' })
+  end
+
+  lua_redis.redis_make_request(task,
+    rule.redis,
+    nil,
+    false,
+    zrange_cb,
+    'ZREVRANGE',
+    { set.prefix, '0', tostring(settings.max_profiles) })
+end
+
-- Function to check or load ANNs from Redis
local function check_anns(worker, cfg, ev_base, rule, process_callback, what)
for _, set in pairs(rule.settings) do