]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Fix] Stabilize neural LLM embedding training and cache keys
authorVsevolod Stakhov <vsevolod@rspamd.com>
Sat, 17 Jan 2026 10:46:13 +0000 (10:46 +0000)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Sat, 17 Jan 2026 10:46:13 +0000 (10:46 +0000)
contrib/neural-embedding-service/README.md
contrib/neural-embedding-service/embedding_service.py
lualib/plugins/neural.lua
lualib/plugins/neural/providers/llm.lua
lualib/redis_scripts/neural_save_unlock.lua
lualib/redis_scripts/neural_train_size.lua
rules/controller/neural.lua
src/plugins/lua/neural.lua

index 5d1635a6d6a4ce53e85f0d541089665ef72f52f1..dacc513a95cd21e9e2f6d138072d13b6865f54ad 100644 (file)
@@ -50,7 +50,7 @@ See the main guide: `doc/neural-llm-embeddings-guide.md`
 |----------|---------|-------------|
 | `EMBEDDING_MODEL` | `BAAI/bge-small-en-v1.5` | FastEmbed model name |
 | `EMBEDDING_PORT` | `8080` | Port to listen on |
-| `EMBEDDING_HOST` | `0.0.0.0` | Host to bind to |
+| `EMBEDDING_HOST` | `0.0.0.0` | Host to bind |
 
 ### Command Line Arguments
 
index 50daceeb28ceb6bc5e125454693f9b36051b3c08..3b3011256ab8ccba64520c67345e589362e32c6b 100644 (file)
@@ -25,7 +25,6 @@ Environment variables:
     EMBEDDING_PORT: Port number (default: 8080)
     EMBEDDING_HOST: Host to bind (default: 0.0.0.0)
 """
-
 import argparse
 import logging
 import os
@@ -56,7 +55,6 @@ model: Optional[TextEmbedding] = None
 model_name: str = DEFAULT_MODEL
 model_dim: int = 0
 
-
 # Request/Response models
 class OllamaEmbeddingRequest(BaseModel):
     """Ollama-compatible embedding request."""
@@ -104,7 +102,6 @@ class HealthResponse(BaseModel):
     dimensions: int
     uptime_seconds: float
 
-
 # Startup time for uptime calculation
 startup_time: float = 0.0
 
@@ -117,17 +114,12 @@ async def lifespan(app: FastAPI):
     logger.info(f"Loading embedding model: {model_name}")
     start = time.time()
 
-    try:
-        model = TextEmbedding(model_name)
-        # Get embedding dimension from a test inference
-        test_embed = list(model.embed(["test"]))[0]
-        model_dim = len(test_embed)
-        elapsed = time.time() - start
-        logger.info(f"Model loaded in {elapsed:.2f}s, dimensions: {model_dim}")
-        startup_time = time.time()
-    except Exception as e:
-        logger.error(f"Failed to load model: {e}")
-        raise
+    model = TextEmbedding(model_name)
+    test_embed = list(model.embed(["test"]))[0]
+    model_dim = len(test_embed)
+    elapsed = time.time() - start
+    logger.info(f"Model loaded in {elapsed:.2f}s, dimensions: {model_dim}")
+    startup_time = time.time()
 
     yield
 
index 5f71db5fc2236877ab10f99d7f39e595df4efd36..919df7f32ca8a82645018233626207fcac591e05 100644 (file)
@@ -588,17 +588,63 @@ local function redis_ann_prefix(rule, settings_name)
 end
 
 -- Compute a stable digest for providers configuration
-local function providers_config_digest(providers_cfg)
+local function providers_config_digest(providers_cfg, rule)
   if not providers_cfg then return nil end
   -- Normalize minimal subset of fields to keep digest stable across equivalent configs
-  local norm = {}
+  local norm = { providers = {} }
+
+  local fusion = rule and rule.fusion or nil
+  if rule then
+    local effective_fusion = {
+      normalization = (fusion and fusion.normalization) or 'none',
+      include_meta = fusion and fusion.include_meta,
+      meta_weight = fusion and fusion.meta_weight,
+      per_provider_pca = fusion and fusion.per_provider_pca,
+    }
+    if effective_fusion.include_meta == nil then
+      effective_fusion.include_meta = true
+    end
+    if effective_fusion.meta_weight == nil then
+      effective_fusion.meta_weight = 1.0
+    end
+    if effective_fusion.per_provider_pca == nil then
+      effective_fusion.per_provider_pca = false
+    end
+    norm.fusion = effective_fusion
+  end
+
+  if rule and rule.max_inputs then
+    norm.max_inputs = rule.max_inputs
+  end
+
+  local gpt_settings = rspamd_config:get_all_opt('gpt') or {}
+
   for i, p in ipairs(providers_cfg) do
-    norm[i] = {
-      type = p.type,
-      name = p.name,
+    local ptype = p.type or p.name or 'unknown'
+    local entry = {
+      type = ptype,
       weight = p.weight or 1.0,
       dim = p.dim,
     }
+
+    if ptype == 'llm' then
+      local llm_type = p.llm_type or p.api or p.backend or gpt_settings.type
+      local model = p.model or gpt_settings.model
+      local max_tokens = p.max_tokens
+      if not max_tokens and gpt_settings.model_parameters and model then
+        local model_cfg = gpt_settings.model_parameters[model] or {}
+        max_tokens = model_cfg.max_completion_tokens or model_cfg.max_tokens
+      end
+      if not max_tokens then
+        max_tokens = gpt_settings.max_tokens
+      end
+
+      entry.llm_type = llm_type
+      entry.model = model
+      entry.max_tokens = max_tokens
+    end
+
+    norm.providers[i] = entry
   end
   return lua_util.table_digest(norm)
 end
@@ -612,7 +658,7 @@ local function collect_features_async(task, rule, profile_or_set, phase, cb)
   local providers_cfg = rule.providers
   if not providers_cfg or #providers_cfg == 0 then
     if rule.disable_symbols_input then
-      cb(nil, { providers = {}, total_dim = 0, digest = providers_config_digest(providers_cfg) })
+      cb(nil, { providers = {}, total_dim = 0, digest = providers_config_digest(providers_cfg, rule) })
       return
     end
     local prov = get_provider('symbols')
@@ -636,7 +682,7 @@ local function collect_features_async(task, rule, profile_or_set, phase, cb)
         cb(#fused > 0 and fused or nil, {
           providers = build_providers_meta(metas) or metas,
           total_dim = #fused,
-          digest = providers_config_digest(providers_cfg),
+          digest = providers_config_digest(providers_cfg, rule),
         })
       end)
       return
@@ -658,7 +704,7 @@ local function collect_features_async(task, rule, profile_or_set, phase, cb)
         providers = build_providers_meta({ meta }) or { meta },
         total_dim = #fused,
         digest = providers_config_digest(
-          providers_cfg)
+          providers_cfg, rule)
       })
     return
   end
@@ -687,7 +733,7 @@ local function collect_features_async(task, rule, profile_or_set, phase, cb)
       local meta = {
         providers = build_providers_meta(metas) or metas,
         total_dim = #fused,
-        digest = providers_config_digest(providers_cfg),
+        digest = providers_config_digest(providers_cfg, rule),
       }
       if #fused == 0 then
         cb(nil, meta)
@@ -1011,7 +1057,7 @@ local function spawn_train(params)
           digest = params.set.digest,
           redis_key = params.set.ann.redis_key,
           version = version,
-          providers_digest = providers_config_digest(params.rule.providers),
+          providers_digest = providers_config_digest(params.rule.providers, params.rule),
         }
 
         local profile_serialized = ucl.to_format(profile, 'json-compact', true)
index 17fc0c9f3e4869b019497f8e09dabf503a11d28b..b256b5314f21322b54af33c4a3bb699a8635de43 100644 (file)
@@ -13,8 +13,12 @@ local llm_common = require "llm_common"
 
 local N = "neural.llm"
 
-local function select_text(task)
-  local input_tbl = llm_common.build_llm_input(task)
+local function select_text(task, llm)
+  local opts = {}
+  if llm and llm.max_tokens then
+    opts.max_tokens = llm.max_tokens
+  end
+  local input_tbl = llm_common.build_llm_input(task, opts)
   return input_tbl
 end
 
@@ -23,6 +27,12 @@ local function compose_llm_settings(pcfg)
   -- Provider identity is pcfg.type=='llm'; backend type is specified via one of these keys
   local llm_type = pcfg.llm_type or pcfg.api or pcfg.backend or gpt_settings.type or 'openai'
   local model = pcfg.model or gpt_settings.model
+  local model_params = gpt_settings.model_parameters or {}
+  local model_cfg = model and model_params[model] or {}
+  local max_tokens = pcfg.max_tokens
+  if not max_tokens then
+    max_tokens = model_cfg.max_completion_tokens or model_cfg.max_tokens or gpt_settings.max_tokens
+  end
   local timeout = pcfg.timeout or gpt_settings.timeout or 2.0
   local url = pcfg.url
   local api_key = pcfg.api_key or gpt_settings.api_key
@@ -38,6 +48,7 @@ local function compose_llm_settings(pcfg)
   return {
     type = llm_type,
     model = model,
+    max_tokens = max_tokens,
     timeout = timeout,
     url = url,
     api_key = api_key,
@@ -53,6 +64,13 @@ local function compose_llm_settings(pcfg)
   }
 end
 
+local function normalize_cache_key_input(input_string)
+  if type(input_string) == 'userdata' then
+    return input_string:str()
+  end
+  return tostring(input_string)
+end
+
 local function extract_embedding(llm_type, parsed)
   if llm_type == 'openai' then
     -- { data = [ { embedding = [...] } ] }
@@ -89,7 +107,7 @@ neural_common.register_provider('llm', {
       end
     end
 
-    local input_tbl = select_text(task)
+    local input_tbl = select_text(task, llm)
     if not input_tbl then
       rspamd_logger.debugm(N, task, 'llm provider has no content to embed; skip')
       cont(nil)
@@ -102,8 +120,9 @@ neural_common.register_provider('llm', {
       input_string = input_string .. "\nSubject: " .. input_tbl.subject
     end
 
+    local input_key = normalize_cache_key_input(input_string)
     rspamd_logger.debugm(N, task, 'llm embedding request: model=%s url=%s len=%s', tostring(llm.model), tostring(llm.url),
-      tostring(#tostring(input_string)))
+      tostring(#input_key))
 
     local body
     if llm.type == 'openai' then
@@ -126,11 +145,18 @@ neural_common.register_provider('llm', {
     }, N)
 
     -- Use raw key and allow cache module to hash/shorten it per context
-    local key = string.format('%s:%s:%s', llm.type, llm.model or 'model', input_string)
+    local key = string.format('%s:%s:%s', llm.type, llm.model or 'model', input_key)
 
     local function finish_with_vec(vec)
       if type(vec) == 'table' and #vec > 0 then
-        local meta = { name = pcfg.name or 'llm', type = 'llm', dim = #vec, weight = ctx.weight or 1.0 }
+        local meta = {
+          name = pcfg.name or 'llm',
+          type = 'llm',
+          dim = #vec,
+          weight = ctx.weight or 1.0,
+          model = llm.model,
+          provider = llm.type,
+        }
         rspamd_logger.debugm(N, task, 'llm embedding result: dim=%s', #vec)
         cont(vec, meta)
       else
index 1ce31afa147c7d2b3f92a38f7e069b809a516044..0c7eaecbc348d363effe77fd549a0e4c02eae11b 100644 (file)
@@ -26,6 +26,8 @@ if KEYS[11] and KEYS[11] ~= '' then
 end
 redis.call('HDEL', KEYS[1], 'lock')
 redis.call('HDEL', KEYS[7], 'lock')
+redis.call('HSET', KEYS[7], 'obsolete', '1')
+redis.call('EXPIRE', KEYS[7], 600)
 redis.call('EXPIRE', KEYS[1], tonumber(KEYS[5]))
 -- expire in 10m, to not face race condition with other rspamd replicas refill deleted keys
 redis.call('EXPIRE', KEYS[7] .. '_spam_set', 600)
index 45ad6a931fbbf74f1f450718477ab0e7d6912af8..1c12660af0fdf6f5b68bd567dfa86b8c03212787 100644 (file)
@@ -4,6 +4,10 @@
 -- returns nspam,nham (or nil if locked)
 
 local prefix = KEYS[1]
+local obsolete = redis.call('HGET', prefix, 'obsolete')
+if obsolete then
+  return 'obsolete'
+end
 local locked = redis.call('HGET', prefix, 'lock')
 if locked then
   local host = redis.call('HGET', prefix, 'hostname') or 'unknown'
index 4c1e3da6a20664ff20116d14ccd059e0f30a3831..628e4a62d78aa2346ea9c1ad1e689d5636e92f92 100644 (file)
@@ -25,6 +25,17 @@ local rspamd_logger = require "rspamd_logger"
 local E = {}
 local N = 'neural'
 
+local function get_request_header(task, name)
+  local hdr = task:get_request_header(name)
+  if type(hdr) == 'table' then
+    hdr = hdr[1]
+  end
+  if hdr then
+    return tostring(hdr)
+  end
+  return nil
+end
+
 -- Controller neural plugin
 
 local learn_request_schema = T.table({
@@ -176,7 +187,7 @@ local function handle_learn_message(task, conn)
     return
   end
 
-  local cls = task:get_request_header('ANN-Train') or task:get_request_header('Class')
+  local cls = get_request_header(task, 'ANN-Train') or get_request_header(task, 'Class')
   if not cls then
     conn:send_error(400, 'missing class header (ANN-Train or Class)')
     return
@@ -188,7 +199,7 @@ local function handle_learn_message(task, conn)
     return
   end
 
-  local rule_name = task:get_request_header('Rule') or 'default'
+  local rule_name = get_request_header(task, 'Rule') or 'default'
   local rule = neural_common.settings.rules[rule_name]
   if not rule then
     conn:send_error(400, 'unknown rule')
@@ -275,7 +286,7 @@ local function handle_learn_message(task, conn)
       version = version,
       digest = set.digest,
       distance = 0,
-      providers_digest = neural_common.providers_config_digest(rule.providers),
+      providers_digest = neural_common.providers_config_digest(rule.providers, rule),
     }
 
     local profile_serialized = ucl.to_format(profile, 'json-compact', true)
index 54426aa6916dc1166570ea0158a1276e308cb316..7054ab92328d554bfd9d1a0bf98dcb9885622f44 100644 (file)
@@ -64,7 +64,7 @@ local function new_ann_profile(task, rule, set, version)
     version = version,
     digest = set.digest,
     distance = 0, -- Since we are using our own profile
-    providers_digest = neural_common.providers_config_digest(rule.providers),
+    providers_digest = neural_common.providers_config_digest(rule.providers, rule),
   }
 
   local ucl = require "ucl"
@@ -144,7 +144,7 @@ local function ann_scores_filter(task)
           local spam_threshold = 0
           if rule.spam_score_threshold then
             spam_threshold = rule.spam_score_threshold
-          elseif rule.roc_enabled and not set.ann.roc_thresholds then
+          elseif rule.roc_enabled and set.ann.roc_thresholds then
             spam_threshold = set.ann.roc_thresholds[1]
           end
 
@@ -166,7 +166,7 @@ local function ann_scores_filter(task)
           local ham_threshold = 0
           if rule.ham_score_threshold then
             ham_threshold = rule.ham_score_threshold
-          elseif rule.roc_enabled and not set.ann.roc_thresholds then
+          elseif rule.roc_enabled and set.ann.roc_thresholds then
             ham_threshold = set.ann.roc_thresholds[2]
           end
 
@@ -194,6 +194,17 @@ local function ann_scores_filter(task)
   end
 end
 
+local function get_ann_train_header(task)
+  local hdr = task:get_request_header('ANN-Train')
+  if type(hdr) == 'table' then
+    hdr = hdr[1]
+  end
+  if hdr then
+    return tostring(hdr):lower()
+  end
+  return nil
+end
+
 local function ann_push_task_result(rule, task, verdict, score, set)
   local train_opts = rule.train
   local learn_spam, learn_ham
@@ -202,10 +213,9 @@ local function ann_push_task_result(rule, task, verdict, score, set)
 
   -- First, honor explicit manual training header if present
   do
-    local hdr = task:get_request_header('ANN-Train')
-    if hdr then
-      local hv = tostring(hdr):lower()
-      lua_util.debugm(N, task, 'found ANN-Train header, enable manual train mode', hv)
+    local hv = get_ann_train_header(task)
+    if hv then
+      lua_util.debugm(N, task, 'found ANN-Train header, enable manual train mode: %s', hv)
       if hv == 'spam' then
         learn_spam = true
         manual_train = true