]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Fix] neural: preserve trained ANN across symcache-driven profile rotation
authorVsevolod Stakhov <vsevolod@rspamd.com>
Sat, 16 May 2026 19:03:12 +0000 (20:03 +0100)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Sat, 16 May 2026 19:03:12 +0000 (20:03 +0100)
When rspamd's symbol cache shifts (any added/removed symbol, even unrelated
to the neural rule), the per-rule symbol digest changes and the plugin
historically picked a brand-new profile — abandoning the previously-trained
ANN at the old redis_key.  In deployments where the input vector is built
from providers (e.g. fasttext_embed conv1d) and `disable_symbols_input` is
set, the symbol list is irrelevant to the vector schema, so the
rotation needlessly reset inference until enough new training data
accumulated.

Make providers_digest the authoritative schema fingerprint when providers
are configured:

* New helper `is_profile_compatible` in lualib/plugins/neural.lua decides
  load eligibility based on providers_digest first; symbol-list drift is
  ignored entirely when `disable_symbols_input = true`, and tolerated
  without bound for hybrid (providers + symbols) rules where symbols form
  only a minor slice of the fused vector.  Pure-symbols rules keep the
  legacy 30% Levenshtein tolerance and now also reject profiles that were
  trained with providers (vector schemas differ).

* process_existing_ann/maybe_train_existing_ann use the new helper, and
  the reload decision in process_existing_ann picks the fresher version
  when the providers schema matches across a symbol-digest shift.

* new_ann_profile triggers an async carryover after ZADD: ZREVRANGE the
  zset, find the most recent prior profile with a matching
  providers_digest, HMGET its ann/roc_thresholds/pca/providers_meta/
  norm_stats, and HMSET them into the fresh redis_key.  Gated on
  HEXISTS new_key ann == 0 so a freshly-trained model is never
  overwritten.

lualib/plugins/neural.lua
src/plugins/lua/neural.lua

index 3ba3799da3a0d6d13b3ac2ee5a336aa568ea0b79..358ad080a0a7b773300afb635bf2ed0b7ab3dbcb 100644 (file)
@@ -717,6 +717,59 @@ local function pending_train_key(rule, set)
     settings.prefix, rule.prefix, set.name)
 end
 
+-- Check whether a candidate profile (loaded from the zset) is compatible with
+-- the running rule/set configuration for the purposes of loading the trained
+-- ANN.  Compatibility is governed by the vector schema fingerprint:
+--
+--   * has_providers + disable_symbols_input: symbols never enter the input
+--     vector, so providers_digest alone is authoritative. Symbol-list drift
+--     is ignored (dist = 0 when providers_digest matches).
+--   * has_providers (hybrid mode): providers_digest must match (otherwise the
+--     fused vector dimensions differ); symbol drift is tolerated and surfaced
+--     as the returned dist for the caller's tie-breaking.
+--   * pure symbols (no providers): legacy Levenshtein-tolerance — accept when
+--     dist < 30% of |set.symbols|.
+--
+-- Profiles trained with providers are rejected for pure-symbol rules (mixed
+-- vector schemas) and vice versa.
+--
+-- Returns (compatible_bool, dist_number).  `dist` is math.huge on rejection.
+local function is_profile_compatible(rule, set, profile_elt, current_providers_digest)
+  if not profile_elt then return false, math.huge end
+  local has_providers = rule.providers and #rule.providers > 0
+
+  if has_providers then
+    if not current_providers_digest or not profile_elt.providers_digest then
+      return false, math.huge
+    end
+    if profile_elt.providers_digest ~= current_providers_digest then
+      return false, math.huge
+    end
+    if rule.disable_symbols_input then
+      return true, 0
+    end
+    local dist = 0
+    if profile_elt.symbols and set.symbols then
+      dist = lua_util.distance_sorted(profile_elt.symbols, set.symbols)
+    end
+    return true, dist
+  end
+
+  -- Pure symbols mode: reject profiles trained with providers (vector schemas
+  -- would be incompatible).
+  if profile_elt.providers_digest then
+    return false, math.huge
+  end
+  if not profile_elt.symbols or not set.symbols then
+    return false, math.huge
+  end
+  local dist = lua_util.distance_sorted(profile_elt.symbols, set.symbols)
+  if dist >= #set.symbols * 0.3 then
+    return false, dist
+  end
+  return true, dist
+end
+
 -- Compute a stable digest for providers configuration
 local function providers_config_digest(providers_cfg, rule)
   if not providers_cfg then return nil end
@@ -1495,6 +1548,7 @@ return {
   gen_unlock_cb = gen_unlock_cb,
   get_provider = get_provider,
   get_rule_settings = get_rule_settings,
+  is_profile_compatible = is_profile_compatible,
   load_scripts = load_scripts,
   module_config = module_config,
   new_ann_key = new_ann_key,
index de9bdb9bc62e3f0d99c64a7798ed11968fd3864d..c6a00c4fa2d7195f059b7fef60e8e906a4b6b1bd 100644 (file)
@@ -57,9 +57,14 @@ end
 local has_blas = rspamd_tensor.has_blas()
 local text_cookie = rspamd_text.cookie
 
+-- Forward declarations
+local maybe_carryover_ann
+local load_ann_profile
+
 -- Creates and stores ANN profile in Redis
 local function new_ann_profile(task, rule, set, version)
   local ann_key = neural_common.new_ann_key(rule, set, version, settings)
+  local providers_digest = neural_common.providers_config_digest(rule.providers, rule)
 
   local profile = {
     symbols = set.symbols,
@@ -67,7 +72,7 @@ local function new_ann_profile(task, rule, set, version)
     version = version,
     digest = set.digest,
     distance = 0, -- Since we are using our own profile
-    providers_digest = neural_common.providers_config_digest(rule.providers, rule),
+    providers_digest = providers_digest,
   }
 
   local ucl = require "ucl"
@@ -80,6 +85,14 @@ local function new_ann_profile(task, rule, set, version)
     else
       rspamd_logger.infox(task, 'created new ANN profile for %s:%s, data stored at prefix %s',
         rule.prefix, set.name, profile.redis_key)
+      -- If a prior profile with the same providers_digest holds trained
+      -- weights, carry them over into the fresh profile key.  This prevents
+      -- a symcache-driven profile rotation from abandoning a still-valid
+      -- ANN whenever the input vector schema is decided by providers
+      -- (rather than the symbol list).
+      if providers_digest then
+        maybe_carryover_ann(task, rule, set, ann_key, providers_digest)
+      end
     end
   end
 
@@ -925,22 +938,25 @@ end
 -- the existing ones.
 -- Use this function to load ANNs as `callback` parameter for `check_anns` function
 local function process_existing_ann(_, ev_base, rule, set, profiles)
-  local my_symbols = set.symbols
+  local has_providers = rule.providers and #rule.providers > 0
+  local current_providers_digest = has_providers and
+      neural_common.providers_config_digest(rule.providers, rule) or nil
   local min_diff = math.huge
   local sel_elt
-  lua_util.debugm(N, rspamd_config, 'process_existing_ann: have %s profiles for %s:%s',
-    type(profiles) == 'table' and #profiles or -1, rule.prefix, set.name)
+  lua_util.debugm(N, rspamd_config,
+    'process_existing_ann: have %s profiles for %s:%s (providers_digest=%s)',
+    type(profiles) == 'table' and #profiles or -1, rule.prefix, set.name,
+    current_providers_digest or 'none')
 
   for _, elt in fun.iter(profiles) do
-    if elt and elt.symbols then
-      local dist = lua_util.distance_sorted(elt.symbols, my_symbols)
-      -- Check distance
-      if dist < #my_symbols * .3 then
-        -- Prefer profiles with smaller distance, or higher version when distance is equal
-        if dist < min_diff or (dist == min_diff and sel_elt and elt.version > sel_elt.version) then
-          min_diff = dist
-          sel_elt = elt
-        end
+    local compatible, dist = neural_common.is_profile_compatible(
+      rule, set, elt, current_providers_digest)
+    if compatible then
+      -- Prefer smaller distance; tie-break on higher version
+      if dist < min_diff
+          or (dist == min_diff and sel_elt and (elt.version or 0) > (sel_elt.version or 0)) then
+        min_diff = dist
+        sel_elt = elt
       end
     end
   end
@@ -961,11 +977,18 @@ local function process_existing_ann(_, ev_base, rule, set, profiles)
     }
     -- We can load element from ANN
     if set.ann then
-      -- We have an existing ANN, probably the same...
+      -- Providers schema acts as the dominant identity when configured: even
+      -- if the symbol-digest portion drifted (symcache shift), a matching
+      -- providers_digest means the vector shape (and therefore the trained
+      -- weights) are still valid.  Reload purely on version freshness in
+      -- that case.
+      local providers_compatible = has_providers and current_providers_digest
+          and set.ann.providers_digest == current_providers_digest
+          and sel_elt.providers_digest == current_providers_digest
+
       if set.ann.digest == sel_elt.digest then
         -- Same ANN, check version
-        if set.ann.version < sel_elt.version then
-          -- Load new ann
+        if (set.ann.version or 0) < (sel_elt.version or 0) then
           rspamd_logger.infox(rspamd_config, 'ann %s is changed, ' ..
             'our version = %s, remote version = %s',
             rule.prefix .. ':' .. set.name,
@@ -979,10 +1002,22 @@ local function process_existing_ann(_, ev_base, rule, set, profiles)
             set.ann.version,
             sel_elt.version)
         end
+      elseif providers_compatible then
+        if (sel_elt.version or 0) > (set.ann.version or 0) then
+          rspamd_logger.infox(rspamd_config,
+            'providers schema matches for %s; reload newer version %s (ours = %s)',
+            rule.prefix .. ':' .. set.name,
+            sel_elt.version, set.ann.version)
+          load_new_ann(rule, ev_base, set, sel_elt, min_diff)
+        else
+          lua_util.debugm(N, rspamd_config,
+            'providers schema matches for %s; our version %s >= remote %s, no reload',
+            rule.prefix .. ':' .. set.name,
+            set.ann.version, sel_elt.version)
+        end
       else
         -- We have some different ANN, so we need to compare distance
-        if set.ann.distance > min_diff then
-          -- Load more specific ANN
+        if (set.ann.distance or math.huge) > min_diff then
           rspamd_logger.infox(rspamd_config, 'more specific ann is available for %s, ' ..
             'our distance = %s, remote distance = %s',
             rule.prefix .. ':' .. set.name,
@@ -1015,7 +1050,9 @@ end
 -- ANN. By our we mean that it has exactly the same symbols in profile.
 -- Use this function to train ANN as `callback` parameter for `check_anns` function
 local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
-  local my_symbols = set.symbols
+  local has_providers = rule.providers and #rule.providers > 0
+  local current_providers_digest = has_providers and
+      neural_common.providers_config_digest(rule.providers, rule) or nil
   local sel_elt
   local lens = {
     spam = 0,
@@ -1024,14 +1061,16 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
   lua_util.debugm(N, rspamd_config, 'maybe_train_existing_ann: %s profiles for %s:%s',
     type(profiles) == 'table' and #profiles or -1, rule.prefix, set.name)
 
+  -- Strict match: training data accumulated against an existing profile
+  -- must come from a compatible vector schema.  is_profile_compatible
+  -- returns dist=0 when symbols are irrelevant (disable_symbols_input) or
+  -- when symbol-lists actually match.
   for _, elt in fun.iter(profiles) do
-    if elt and elt.symbols then
-      local dist = lua_util.distance_sorted(elt.symbols, my_symbols)
-      -- Check distance
-      if dist == 0 then
-        sel_elt = elt
-        break
-      end
+    local compatible, dist = neural_common.is_profile_compatible(
+      rule, set, elt, current_providers_digest)
+    if compatible and dist == 0 then
+      sel_elt = elt
+      break
     end
   end
 
@@ -1175,7 +1214,7 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
 end
 
 -- Used to deserialise ANN element from a list
-local function load_ann_profile(element)
+load_ann_profile = function(element)
   local ucl = require "ucl"
 
   local parser = ucl.parser()
@@ -1196,6 +1235,127 @@ local function load_ann_profile(element)
   end
 end
 
+-- Async carryover: look up the most recent zset entry with the same
+-- providers_digest and a trained ANN blob, then copy its
+-- ann/roc_thresholds/pca/providers_meta/norm_stats fields into the freshly
+-- created profile's redis_key.  Only runs when the new key has no ANN yet,
+-- so this never overwrites a freshly-trained model.
+maybe_carryover_ann = function(task, rule, set, new_key, target_providers_digest)
+  local function zrange_cb(err, data)
+    if err or type(data) ~= 'table' then
+      lua_util.debugm(N, task, 'carryover: cannot read zset %s: %s',
+        set.prefix, err)
+      return
+    end
+
+    local source_key
+    for _, raw in ipairs(data) do
+      local profile = load_ann_profile(raw)
+      if profile
+          and profile.providers_digest == target_providers_digest
+          and profile.redis_key ~= new_key then
+        source_key = profile.redis_key
+        break
+      end
+    end
+
+    if not source_key then
+      lua_util.debugm(N, task,
+        'carryover: no prior profile with matching providers_digest for %s:%s',
+        rule.prefix, set.name)
+      return
+    end
+
+    local function hmset_cb(hmset_err)
+      if hmset_err then
+        rspamd_logger.errx(task,
+          'carryover: cannot copy ANN from %s to %s: %s',
+          source_key, new_key, hmset_err)
+      else
+        rspamd_logger.infox(task,
+          'carryover: copied ANN weights from %s into fresh profile %s ' ..
+          '(providers_digest unchanged)',
+          source_key, new_key)
+      end
+    end
+
+    local function hmget_cb(hmget_err, hmget_data)
+      if hmget_err or type(hmget_data) ~= 'table' then
+        lua_util.debugm(N, task,
+          'carryover: HMGET error for %s: %s', source_key, hmget_err)
+        return
+      end
+      if not (type(hmget_data[1]) == 'userdata' and hmget_data[1].cookie == text_cookie) then
+        lua_util.debugm(N, task,
+          'carryover: source key %s has no ANN blob', source_key)
+        return
+      end
+
+      local fields = { 'ann', 'roc_thresholds', 'pca', 'providers_meta', 'norm_stats' }
+      local args = { new_key }
+      for i, fname in ipairs(fields) do
+        local v = hmget_data[i]
+        if type(v) == 'userdata' and v.cookie == text_cookie then
+          args[#args + 1] = fname
+          args[#args + 1] = v
+        end
+      end
+
+      if #args <= 1 then
+        lua_util.debugm(N, task,
+          'carryover: nothing to copy from %s', source_key)
+        return
+      end
+
+      lua_redis.redis_make_request(task,
+        rule.redis,
+        nil,
+        true,
+        hmset_cb,
+        'HMSET',
+        args)
+    end
+
+    local function exists_cb(hex_err, hex_data)
+      if hex_err then
+        lua_util.debugm(N, task,
+          'carryover: HEXISTS error for %s: %s', new_key, hex_err)
+        return
+      end
+      if tonumber(hex_data) == 1 then
+        lua_util.debugm(N, task,
+          'carryover: %s already has an ANN, skipping copy', new_key)
+        return
+      end
+
+      lua_redis.redis_make_request(task,
+        rule.redis,
+        nil,
+        false,
+        hmget_cb,
+        'HMGET',
+        { source_key, 'ann', 'roc_thresholds', 'pca', 'providers_meta', 'norm_stats' },
+        { opaque_data = true })
+    end
+
+    lua_redis.redis_make_request(task,
+      rule.redis,
+      nil,
+      false,
+      exists_cb,
+      'HEXISTS',
+      { new_key, 'ann' })
+  end
+
+  lua_redis.redis_make_request(task,
+    rule.redis,
+    nil,
+    false,
+    zrange_cb,
+    'ZREVRANGE',
+    { set.prefix, '0', tostring(settings.max_profiles) })
+end
+
 -- Function to check or load ANNs from Redis
 local function check_anns(worker, cfg, ev_base, rule, process_callback, what)
   for _, set in pairs(rule.settings) do