From: Vsevolod Stakhov Date: Thu, 28 Aug 2025 12:57:02 +0000 (+0100) Subject: [Project] Fix various other issues X-Git-Tag: 3.13.0~22^2~2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=a1975b95a9447406b475c097d4eda7d2c9cef866;p=thirdparty%2Frspamd.git [Project] Fix various other issues --- diff --git a/lualib/llm_common.lua b/lualib/llm_common.lua index a89aafa438..a254a1fed7 100644 --- a/lualib/llm_common.lua +++ b/lualib/llm_common.lua @@ -7,6 +7,7 @@ local lua_mime = require "lua_mime" local fun = require "fun" local M = {} +local N = 'llm_common' local function get_meta_llm_content(task) local url_content = "Url domains: no urls found" @@ -34,11 +35,13 @@ function M.build_llm_input(task, opts) local sel_part = lua_mime.get_displayed_text_part(task) if not sel_part then + lua_util.debugm(N, task, 'no displayed text part found') return nil, nil end local nwords = sel_part:get_words_count() or 0 if nwords < 5 then + lua_util.debugm(N, task, 'too few words in part: %s', nwords) return nil, sel_part end @@ -51,6 +54,7 @@ function M.build_llm_input(task, opts) else text = table.concat(words, ' ') end + lua_util.debugm(N, task, 'truncated text to %s tokens (had %s words)', max_tokens, nwords) else -- Keep rspamd_text (userdata) intact; consumers (http/ucl) can use it directly text = sel_part:get_content_oneline() or '' diff --git a/lualib/plugins/neural.lua b/lualib/plugins/neural.lua index 22e77cb4be..17661c1e94 100644 --- a/lualib/plugins/neural.lua +++ b/lualib/plugins/neural.lua @@ -682,10 +682,29 @@ local function collect_features_async(task, rule, profile_or_set, phase, cb) end) end - remaining = #providers_cfg + -- Include metatokens as an extra provider when configured + local include_meta = rule.fusion and rule.fusion.include_meta + local meta_weight = (rule.fusion and rule.fusion.meta_weight) or 1.0 + + remaining = #providers_cfg + (include_meta and 1 or 0) for i, pcfg in ipairs(providers_cfg) do start_provider(i, pcfg) end + + if include_meta then + local prov = get_provider('symbols') + if prov and prov.collect_async then + prov.collect_async(task, { profile = profile_or_set, weight = meta_weight, phase = phase }, function(vec, meta) + if vec then + metas[#metas + 1] = { name = 'symbols', type = 'symbols', dim = #vec, weight = meta_weight } + vectors[#vectors + 1] = vec + end + maybe_finish() + end) + else + maybe_finish() + end + end end -- This function receives training vectors, checks them, spawn learning and saves ANN in Redis diff --git a/lualib/plugins/neural/providers/llm.lua b/lualib/plugins/neural/providers/llm.lua index 4f17979c50..8f08fbb57b 100644 --- a/lualib/plugins/neural/providers/llm.lua +++ b/lualib/plugins/neural/providers/llm.lua @@ -20,7 +20,8 @@ end local function compose_llm_settings(pcfg) local gpt_settings = rspamd_config:get_all_opt('gpt') or {} - local llm_type = pcfg.type or gpt_settings.type or 'openai' + -- Provider identity is pcfg.type=='llm'; backend type is specified via one of these keys + local llm_type = pcfg.llm_type or pcfg.api or pcfg.backend or gpt_settings.type or 'openai' local model = pcfg.model or gpt_settings.model local timeout = pcfg.timeout or gpt_settings.timeout or 2.0 local url = pcfg.url @@ -42,8 +43,8 @@ local function compose_llm_settings(pcfg) api_key = api_key, cache_ttl = pcfg.cache_ttl or 86400, cache_prefix = pcfg.cache_prefix or 'neural_llm', - cache_hash_len = pcfg.cache_hash_len or 16, - cache_use_hashing = pcfg.cache_use_hashing ~= false, + cache_hash_len = pcfg.cache_hash_len or 32, + cache_use_hashing = (pcfg.cache_use_hashing ~= false), } end @@ -63,19 +64,21 @@ local function extract_embedding(llm_type, parsed) end neural_common.register_provider('llm', { - collect = function(task, ctx) + collect_async = function(task, ctx, cont) local pcfg = ctx.config or {} local llm = compose_llm_settings(pcfg) if not llm.model then rspamd_logger.debugm(N, task, 'llm provider missing model; skip') - return nil + cont(nil) + return end local input_tbl = select_text(task) if not input_tbl then rspamd_logger.debugm(N, task, 'llm provider has no content to embed; skip') - return nil + cont(nil) + return end -- Build request input string (text then optional subject), keeping rspamd_text intact @@ -84,6 +87,9 @@ neural_common.register_provider('llm', { input_string = input_string .. "\nSubject: " .. input_tbl.subject end + rspamd_logger.debugm(N, task, 'llm embedding request: model=%s url=%s len=%s', tostring(llm.model), tostring(llm.url), + tostring(#tostring(input_string))) + local body if llm.type == 'openai' then body = { model = llm.model, input = input_string } @@ -91,10 +97,11 @@ neural_common.register_provider('llm', { body = { model = llm.model, prompt = input_string } else rspamd_logger.debugm(N, task, 'unsupported llm type: %s', llm.type) - return nil + cont(nil) + return end - -- Redis cache: hash the final input string only (IUF is trivial here) + -- Redis cache: hash the final input string only local cache_ctx = lua_cache.create_cache_context(neural_common.redis_params, { cache_prefix = llm.cache_prefix, cache_ttl = llm.cache_ttl, @@ -103,8 +110,49 @@ neural_common.register_provider('llm', { cache_use_hashing = llm.cache_use_hashing, }, N) - local hasher = require 'rspamd_cryptobox_hash' - local key = string.format('%s:%s:%s', llm.type, llm.model or 'model', hasher.create(input_string):hex()) + -- Use raw key and allow cache module to hash/shorten it per context + local key = string.format('%s:%s:%s', llm.type, llm.model or 'model', input_string) + + local function finish_with_vec(vec) + if type(vec) == 'table' and #vec > 0 then + local meta = { name = pcfg.name or 'llm', type = 'llm', dim = #vec, weight = ctx.weight or 1.0 } + rspamd_logger.debugm(N, task, 'llm embedding result: dim=%s', #vec) + cont(vec, meta) + else + rspamd_logger.debugm(N, task, 'llm embedding result: empty') + cont(nil) + end + end + + local function http_cb(err, code, resp, _) + if err then + rspamd_logger.debugm(N, task, 'llm http error: %s', err) + cont(nil) + return + end + if code ~= 200 or not resp then + rspamd_logger.debugm(N, task, 'llm bad http code: %s', code) + cont(nil) + return + end + + local parser = ucl.parser() + local ok, perr = parser:parse_string(resp) + if not ok then + rspamd_logger.debugm(N, task, 'llm cannot parse reply: %s', perr) + cont(nil) + return + end + local parsed = parser:get_object() + local emb = extract_embedding(llm.type, parsed) + if type(emb) == 'table' then + lua_cache.cache_set(task, key, emb, cache_ctx) + finish_with_vec(emb) + else + rspamd_logger.debugm(N, task, 'llm embedding parse: no embedding field') + cont(nil) + end + end local function do_request_and_cache() local headers = { ['Content-Type'] = 'application/json' } @@ -122,41 +170,25 @@ neural_common.register_provider('llm', { task = task, method = 'POST', use_gzip = true, + keepalive = true, + callback = http_cb, } - local function http_cb(err, code, resp, _) - if err then - rspamd_logger.debugm(N, task, 'llm http error: %s', err) - return - end - if code ~= 200 or not resp then - rspamd_logger.debugm(N, task, 'llm bad http code: %s', code) - return - end - - local parser = ucl.parser() - local ok, perr = parser:parse_string(resp) - if not ok then - rspamd_logger.debugm(N, task, 'llm cannot parse reply: %s', perr) - return - end - local parsed = parser:get_object() - local emb = extract_embedding(llm.type, parsed) - if type(emb) == 'table' then - cache_ctx:set_cached(key, emb) - neural_common.append_provider_vector(ctx, { provider = 'llm', vector = emb }) - end - end - - rspamd_http.request(http_params, http_cb) - end - - local cached = cache_ctx:get_cached(key) - if type(cached) == 'table' then - neural_common.append_provider_vector(ctx, { provider = 'llm', vector = cached }) - return + rspamd_http.request(http_params) end - do_request_and_cache() + -- Use async cache API + lua_cache.cache_get(task, key, cache_ctx, llm.timeout or 2.0, + function() + -- Uncached path + do_request_and_cache() + end, + function(_, err, data) + if data and type(data) == 'table' then + finish_with_vec(data) + else + do_request_and_cache() + end + end) end, }) diff --git a/rules/controller/neural.lua b/rules/controller/neural.lua index 6e8cd80b58..0aace1cc1d 100644 --- a/rules/controller/neural.lua +++ b/rules/controller/neural.lua @@ -167,6 +167,15 @@ end -- - Rule: rule name (optional, default 'default') local function handle_learn_message(task, conn) lua_util.debugm(N, task, 'controller.neural: learn_message called') + + -- Ensure the message is parsed so LLM providers can access text parts + local ok_parse = task:process_message() + if not ok_parse then + lua_util.debugm(N, task, 'controller.neural: cannot process message MIME, abort') + conn:send_error(400, 'cannot parse message for learning') + return + end + local cls = task:get_request_header('ANN-Train') or task:get_request_header('Class') if not cls then conn:send_error(400, 'missing class header (ANN-Train or Class)') @@ -203,7 +212,34 @@ local function handle_learn_message(task, conn) end local set = neural_common.get_rule_settings(task, rule) - if not set or not set.ann or not set.ann.redis_key then + if not set then + lua_util.debugm(N, task, 'controller.neural: no settings resolved for rule=%s; falling back to first available set', + rule_name) + for sid, s in pairs(rule.settings or {}) do + if type(s) == 'table' then + set = s + set.name = set.name or sid + break + end + end + end + + -- Derive redis base key even if ANN not yet initialized + local redis_base + if set and set.ann and set.ann.redis_key then + redis_base = set.ann.redis_key + elseif set then + local ok, prefix = pcall(neural_common.redis_ann_prefix, rule, set.name) + if ok and prefix then + redis_base = prefix + lua_util.debugm(N, task, 'controller.neural: derived redis base key for rule=%s set=%s -> %s', rule_name, set.name, + redis_base) + end + end + + if not set or not redis_base then + lua_util.debugm(N, task, 'controller.neural: invalid set or redis key for learning; set=%s ann=%s', + tostring(set ~= nil), set and tostring(set.ann ~= nil) or 'nil') conn:send_error(400, 'invalid rule settings for learning') return end @@ -211,7 +247,14 @@ local function handle_learn_message(task, conn) local function after_collect(vec) lua_util.debugm(N, task, 'controller.neural: learn_message after_collect, vector=%s', type(vec)) if not vec then - vec = neural_common.result_to_vector(task, set) + if rule.providers and #rule.providers > 0 then + lua_util.debugm(N, task, + 'controller.neural: no vector from providers; skip training to keep dimensions consistent') + conn:send_error(400, 'no vector collected from providers') + return + else + vec = neural_common.result_to_vector(task, set) + end end if type(vec) ~= 'table' then @@ -219,8 +262,22 @@ local function handle_learn_message(task, conn) return end + -- Preview vector for debugging + local function preview_vector(v) + local n = #v + local limit = math.min(n, 8) + local parts = {} + for i = 1, limit do + parts[#parts + 1] = string.format('%.4f', tonumber(v[i]) or 0) + end + return n, table.concat(parts, ',') + end + + local vlen, vhead = preview_vector(vec) + lua_util.debugm(N, task, 'controller.neural: vector size=%s head=[%s]', vlen, vhead) + local compressed = rspamd_util.zstd_compress(table.concat(vec, ';')) - local target_key = string.format('%s_%s_set', set.ann.redis_key, learn_type) + local target_key = string.format('%s_%s_set', redis_base, learn_type) local function learn_vec_cb(redis_err) if redis_err then