]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Project] Fix various other issues
authorVsevolod Stakhov <vsevolod@rspamd.com>
Thu, 28 Aug 2025 12:57:02 +0000 (13:57 +0100)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Thu, 28 Aug 2025 12:57:02 +0000 (13:57 +0100)
lualib/llm_common.lua
lualib/plugins/neural.lua
lualib/plugins/neural/providers/llm.lua
rules/controller/neural.lua

index a89aafa438ae65aa6c66a90dd6f266957b57b3a3..a254a1fed7faebaa4edf3d97124e5df2314c9e12 100644 (file)
@@ -7,6 +7,7 @@ local lua_mime = require "lua_mime"
 local fun = require "fun"
 
 local M = {}
+local N = 'llm_common'
 
 local function get_meta_llm_content(task)
   local url_content = "Url domains: no urls found"
@@ -34,11 +35,13 @@ function M.build_llm_input(task, opts)
 
   local sel_part = lua_mime.get_displayed_text_part(task)
   if not sel_part then
+    lua_util.debugm(N, task, 'no displayed text part found')
     return nil, nil
   end
 
   local nwords = sel_part:get_words_count() or 0
   if nwords < 5 then
+    lua_util.debugm(N, task, 'too few words in part: %s', nwords)
     return nil, sel_part
   end
 
@@ -51,6 +54,7 @@ function M.build_llm_input(task, opts)
     else
       text = table.concat(words, ' ')
     end
+    lua_util.debugm(N, task, 'truncated text to %s tokens (had %s words)', max_tokens, nwords)
   else
     -- Keep rspamd_text (userdata) intact; consumers (http/ucl) can use it directly
     text = sel_part:get_content_oneline() or ''
index 22e77cb4bea018a98b695d0d1440685312ee1e8f..17661c1e948a9ae3e3e4492432246f2eff9c1cc8 100644 (file)
@@ -682,10 +682,29 @@ local function collect_features_async(task, rule, profile_or_set, phase, cb)
     end)
   end
 
-  remaining = #providers_cfg
+  -- Include metatokens as an extra provider when configured
+  local include_meta = rule.fusion and rule.fusion.include_meta
+  local meta_weight = (rule.fusion and rule.fusion.meta_weight) or 1.0
+
+  remaining = #providers_cfg + (include_meta and 1 or 0)
   for i, pcfg in ipairs(providers_cfg) do
     start_provider(i, pcfg)
   end
+
+  if include_meta then
+    local prov = get_provider('symbols')
+    if prov and prov.collect_async then
+      prov.collect_async(task, { profile = profile_or_set, weight = meta_weight, phase = phase }, function(vec, meta)
+        if vec then
+          metas[#metas + 1] = { name = 'symbols', type = 'symbols', dim = #vec, weight = meta_weight }
+          vectors[#vectors + 1] = vec
+        end
+        maybe_finish()
+      end)
+    else
+      maybe_finish()
+    end
+  end
 end
 
 -- This function receives training vectors, checks them, spawn learning and saves ANN in Redis
index 4f17979c50ac453ff190d0c73f5d1ae92163b8af..8f08fbb57b9b58d416558dd25a7fe7d1071685db 100644 (file)
@@ -20,7 +20,8 @@ end
 
 local function compose_llm_settings(pcfg)
   local gpt_settings = rspamd_config:get_all_opt('gpt') or {}
-  local llm_type = pcfg.type or gpt_settings.type or 'openai'
+  -- Provider identity is pcfg.type=='llm'; backend type is specified via one of these keys
+  local llm_type = pcfg.llm_type or pcfg.api or pcfg.backend or gpt_settings.type or 'openai'
   local model = pcfg.model or gpt_settings.model
   local timeout = pcfg.timeout or gpt_settings.timeout or 2.0
   local url = pcfg.url
@@ -42,8 +43,8 @@ local function compose_llm_settings(pcfg)
     api_key = api_key,
     cache_ttl = pcfg.cache_ttl or 86400,
     cache_prefix = pcfg.cache_prefix or 'neural_llm',
-    cache_hash_len = pcfg.cache_hash_len or 16,
-    cache_use_hashing = pcfg.cache_use_hashing ~= false,
+    cache_hash_len = pcfg.cache_hash_len or 32,
+    cache_use_hashing = (pcfg.cache_use_hashing ~= false),
   }
 end
 
@@ -63,19 +64,21 @@ local function extract_embedding(llm_type, parsed)
 end
 
 neural_common.register_provider('llm', {
-  collect = function(task, ctx)
+  collect_async = function(task, ctx, cont)
     local pcfg = ctx.config or {}
     local llm = compose_llm_settings(pcfg)
 
     if not llm.model then
       rspamd_logger.debugm(N, task, 'llm provider missing model; skip')
-      return nil
+      cont(nil)
+      return
     end
 
     local input_tbl = select_text(task)
     if not input_tbl then
       rspamd_logger.debugm(N, task, 'llm provider has no content to embed; skip')
-      return nil
+      cont(nil)
+      return
     end
 
     -- Build request input string (text then optional subject), keeping rspamd_text intact
@@ -84,6 +87,9 @@ neural_common.register_provider('llm', {
       input_string = input_string .. "\nSubject: " .. input_tbl.subject
     end
 
+    rspamd_logger.debugm(N, task, 'llm embedding request: model=%s url=%s len=%s', tostring(llm.model), tostring(llm.url),
+      tostring(#tostring(input_string)))
+
     local body
     if llm.type == 'openai' then
       body = { model = llm.model, input = input_string }
@@ -91,10 +97,11 @@ neural_common.register_provider('llm', {
       body = { model = llm.model, prompt = input_string }
     else
       rspamd_logger.debugm(N, task, 'unsupported llm type: %s', llm.type)
-      return nil
+      cont(nil)
+      return
     end
 
-    -- Redis cache: hash the final input string only (IUF is trivial here)
+    -- Redis cache: hash the final input string only
     local cache_ctx = lua_cache.create_cache_context(neural_common.redis_params, {
       cache_prefix = llm.cache_prefix,
       cache_ttl = llm.cache_ttl,
@@ -103,8 +110,49 @@ neural_common.register_provider('llm', {
       cache_use_hashing = llm.cache_use_hashing,
     }, N)
 
-    local hasher = require 'rspamd_cryptobox_hash'
-    local key = string.format('%s:%s:%s', llm.type, llm.model or 'model', hasher.create(input_string):hex())
+    -- Use raw key and allow cache module to hash/shorten it per context
+    local key = string.format('%s:%s:%s', llm.type, llm.model or 'model', input_string)
+
+    local function finish_with_vec(vec)
+      if type(vec) == 'table' and #vec > 0 then
+        local meta = { name = pcfg.name or 'llm', type = 'llm', dim = #vec, weight = ctx.weight or 1.0 }
+        rspamd_logger.debugm(N, task, 'llm embedding result: dim=%s', #vec)
+        cont(vec, meta)
+      else
+        rspamd_logger.debugm(N, task, 'llm embedding result: empty')
+        cont(nil)
+      end
+    end
+
+    local function http_cb(err, code, resp, _)
+      if err then
+        rspamd_logger.debugm(N, task, 'llm http error: %s', err)
+        cont(nil)
+        return
+      end
+      if code ~= 200 or not resp then
+        rspamd_logger.debugm(N, task, 'llm bad http code: %s', code)
+        cont(nil)
+        return
+      end
+
+      local parser = ucl.parser()
+      local ok, perr = parser:parse_string(resp)
+      if not ok then
+        rspamd_logger.debugm(N, task, 'llm cannot parse reply: %s', perr)
+        cont(nil)
+        return
+      end
+      local parsed = parser:get_object()
+      local emb = extract_embedding(llm.type, parsed)
+      if type(emb) == 'table' then
+        lua_cache.cache_set(task, key, emb, cache_ctx)
+        finish_with_vec(emb)
+      else
+        rspamd_logger.debugm(N, task, 'llm embedding parse: no embedding field')
+        cont(nil)
+      end
+    end
 
     local function do_request_and_cache()
       local headers = { ['Content-Type'] = 'application/json' }
@@ -122,41 +170,25 @@ neural_common.register_provider('llm', {
         task = task,
         method = 'POST',
         use_gzip = true,
+        keepalive = true,
+        callback = http_cb,
       }
 
-      local function http_cb(err, code, resp, _)
-        if err then
-          rspamd_logger.debugm(N, task, 'llm http error: %s', err)
-          return
-        end
-        if code ~= 200 or not resp then
-          rspamd_logger.debugm(N, task, 'llm bad http code: %s', code)
-          return
-        end
-
-        local parser = ucl.parser()
-        local ok, perr = parser:parse_string(resp)
-        if not ok then
-          rspamd_logger.debugm(N, task, 'llm cannot parse reply: %s', perr)
-          return
-        end
-        local parsed = parser:get_object()
-        local emb = extract_embedding(llm.type, parsed)
-        if type(emb) == 'table' then
-          cache_ctx:set_cached(key, emb)
-          neural_common.append_provider_vector(ctx, { provider = 'llm', vector = emb })
-        end
-      end
-
-      rspamd_http.request(http_params, http_cb)
-    end
-
-    local cached = cache_ctx:get_cached(key)
-    if type(cached) == 'table' then
-      neural_common.append_provider_vector(ctx, { provider = 'llm', vector = cached })
-      return
+      rspamd_http.request(http_params)
     end
 
-    do_request_and_cache()
+    -- Use async cache API
+    lua_cache.cache_get(task, key, cache_ctx, llm.timeout or 2.0,
+      function()
+        -- Uncached path
+        do_request_and_cache()
+      end,
+      function(_, err, data)
+        if data and type(data) == 'table' then
+          finish_with_vec(data)
+        else
+          do_request_and_cache()
+        end
+      end)
   end,
 })
index 6e8cd80b587d21326e96755c0a56bec483b88c60..0aace1cc1d2af2df0bb158b539edff6437e1d17c 100644 (file)
@@ -167,6 +167,15 @@ end
 --  - Rule: rule name (optional, default 'default')
 local function handle_learn_message(task, conn)
   lua_util.debugm(N, task, 'controller.neural: learn_message called')
+
+  -- Ensure the message is parsed so LLM providers can access text parts
+  local ok_parse = task:process_message()
+  if not ok_parse then
+    lua_util.debugm(N, task, 'controller.neural: cannot process message MIME, abort')
+    conn:send_error(400, 'cannot parse message for learning')
+    return
+  end
+
   local cls = task:get_request_header('ANN-Train') or task:get_request_header('Class')
   if not cls then
     conn:send_error(400, 'missing class header (ANN-Train or Class)')
@@ -203,7 +212,34 @@ local function handle_learn_message(task, conn)
   end
 
   local set = neural_common.get_rule_settings(task, rule)
-  if not set or not set.ann or not set.ann.redis_key then
+  if not set then
+    lua_util.debugm(N, task, 'controller.neural: no settings resolved for rule=%s; falling back to first available set',
+      rule_name)
+    for sid, s in pairs(rule.settings or {}) do
+      if type(s) == 'table' then
+        set = s
+        set.name = set.name or sid
+        break
+      end
+    end
+  end
+
+  -- Derive redis base key even if ANN not yet initialized
+  local redis_base
+  if set and set.ann and set.ann.redis_key then
+    redis_base = set.ann.redis_key
+  elseif set then
+    local ok, prefix = pcall(neural_common.redis_ann_prefix, rule, set.name)
+    if ok and prefix then
+      redis_base = prefix
+      lua_util.debugm(N, task, 'controller.neural: derived redis base key for rule=%s set=%s -> %s', rule_name, set.name,
+        redis_base)
+    end
+  end
+
+  if not set or not redis_base then
+    lua_util.debugm(N, task, 'controller.neural: invalid set or redis key for learning; set=%s ann=%s',
+      tostring(set ~= nil), set and tostring(set.ann ~= nil) or 'nil')
     conn:send_error(400, 'invalid rule settings for learning')
     return
   end
@@ -211,7 +247,14 @@ local function handle_learn_message(task, conn)
   local function after_collect(vec)
     lua_util.debugm(N, task, 'controller.neural: learn_message after_collect, vector=%s', type(vec))
     if not vec then
-      vec = neural_common.result_to_vector(task, set)
+      if rule.providers and #rule.providers > 0 then
+        lua_util.debugm(N, task,
+          'controller.neural: no vector from providers; skip training to keep dimensions consistent')
+        conn:send_error(400, 'no vector collected from providers')
+        return
+      else
+        vec = neural_common.result_to_vector(task, set)
+      end
     end
 
     if type(vec) ~= 'table' then
@@ -219,8 +262,22 @@ local function handle_learn_message(task, conn)
       return
     end
 
+    -- Preview vector for debugging
+    local function preview_vector(v)
+      local n = #v
+      local limit = math.min(n, 8)
+      local parts = {}
+      for i = 1, limit do
+        parts[#parts + 1] = string.format('%.4f', tonumber(v[i]) or 0)
+      end
+      return n, table.concat(parts, ',')
+    end
+
+    local vlen, vhead = preview_vector(vec)
+    lua_util.debugm(N, task, 'controller.neural: vector size=%s head=[%s]', vlen, vhead)
+
     local compressed = rspamd_util.zstd_compress(table.concat(vec, ';'))
-    local target_key = string.format('%s_%s_set', set.ann.redis_key, learn_type)
+    local target_key = string.format('%s_%s_set', redis_base, learn_type)
 
     local function learn_vec_cb(redis_err)
       if redis_err then