From: Vsevolod Stakhov
Date: Mon, 18 Aug 2025 12:58:39 +0000 (+0100)
Subject: [Minor] Don't use coroutines
X-Git-Tag: 3.13.0~22^2~7
X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=f60a55f6c5d1fbe6844390dd7a36889d9cd09dc7;p=thirdparty%2Frspamd.git

[Minor] Don't use coroutines
---

diff --git a/lualib/plugins/neural.lua b/lualib/plugins/neural.lua
index b13c6a8273..22e77cb4be 100644
--- a/lualib/plugins/neural.lua
+++ b/lualib/plugins/neural.lua
@@ -130,10 +130,13 @@ local result_to_vector
 -- Built-in symbols provider (compatibility path)
 register_provider('symbols', {
   collect = function(task, ctx)
-    -- ctx.profile is expected for symbols provider
     local vec = result_to_vector(task, ctx.profile)
     return vec, { name = 'symbols', type = 'symbols', dim = #vec, weight = ctx.weight or 1.0 }
-  end
+  end,
+  collect_async = function(task, ctx, cont)
+    local vec = result_to_vector(task, ctx.profile)
+    cont(vec, { name = 'symbols', type = 'symbols', dim = #vec, weight = ctx.weight or 1.0 })
+  end,
 })
 
 local function load_scripts()
@@ -566,76 +569,124 @@ end
 
 -- If no providers configured, fallback to symbols provider unless disabled
 -- phase: 'infer' | 'train'
-local function collect_features(task, rule, profile_or_set, phase)
-  local vectors = {}
-  local metas = {}
+-- Removed synchronous collect_features; use collect_features_async instead
+-- Async version: runs providers in parallel and calls cb(fused, meta) when done
+local function collect_features_async(task, rule, profile_or_set, phase, cb)
   local providers_cfg = rule.providers
 
   if not providers_cfg or #providers_cfg == 0 then
-    if not rule.disable_symbols_input then
-      local prov = get_provider('symbols')
-      if prov then
-        local vec, meta = prov.collect(task, { profile = profile_or_set, weight = 1.0 })
+    if rule.disable_symbols_input then
+      cb(nil, { providers = {}, total_dim = 0, digest = providers_config_digest(providers_cfg) })
+      return
+    end
+    local prov = get_provider('symbols')
+    if prov and prov.collect_async then
+      prov.collect_async(task, { profile = profile_or_set, weight = 1.0, phase = phase }, function(vec, meta)
+        local metas = {}
         if vec then
-          vectors[#vectors + 1] = vec
-          metas[#metas + 1] = meta
+          metas[1] = meta
         end
-      end
-    end
-  else
-    for _, pcfg in ipairs(providers_cfg) do
-      local prov = get_provider(pcfg.type or pcfg.name)
-      if prov then
-        local ok, vec, meta = pcall(function()
-          return prov.collect(task, {
-            profile = profile_or_set,
-            rule = rule,
-            config = pcfg,
-            weight = pcfg.weight or 1.0,
-            phase = phase,
-          })
-        end)
-        if ok and vec then
-          if meta then
-            meta.weight = pcfg.weight or meta.weight or 1.0
+        local fused = {}
+        if vec then
+          local w = (meta and meta.weight) or 1.0
+          local norm_mode = (rule.fusion and rule.fusion.normalization) or 'none'
+          if norm_mode ~= 'none' then
+            vec = apply_normalization(vec, norm_mode)
+          end
+          for _, x in ipairs(vec) do
+            fused[#fused + 1] = x * w
           end
-          vectors[#vectors + 1] = vec
-          metas[#metas + 1] = meta or
-              { name = pcfg.name or pcfg.type, type = pcfg.type, dim = #vec, weight = pcfg.weight or 1.0 }
-        else
-          rspamd_logger.debugm(N, rspamd_config, 'provider %s failed to collect features', pcfg.type or pcfg.name)
         end
-      else
-        rspamd_logger.debugm(N, rspamd_config, 'provider %s is not registered', pcfg.type or pcfg.name)
-      end
+        cb(#fused > 0 and fused or nil, {
+          providers = build_providers_meta(metas) or metas,
+          total_dim = #fused,
+          digest = providers_config_digest(providers_cfg),
+        })
+      end)
+      return
     end
-  end
-
-  -- Simple fusion by concatenation; optional per-provider weight scaling
-  local fused = {}
-  for i, v in ipairs(vectors) do
-    local w = (metas[i] and metas[i].weight) or 1.0
-    -- Apply normalization if requested
+    -- Fallback: direct symbols compute
+    local vec = result_to_vector(task, profile_or_set)
+    local meta = { name = 'symbols', type = 'symbols', dim = #vec, weight = 1.0 }
+    local fused = {}
+    local w = 1.0
     local norm_mode = (rule.fusion and rule.fusion.normalization) or 'none'
     if norm_mode ~= 'none' then
-      v = apply_normalization(v, norm_mode)
+      vec = apply_normalization(vec, norm_mode)
     end
-    for _, x in ipairs(v) do
+    for _, x in ipairs(vec) do
       fused[#fused + 1] = x * w
     end
+    cb(fused, {
+      providers = build_providers_meta({ meta }) or { meta },
+      total_dim = #fused,
+      digest = providers_config_digest(providers_cfg),
+    })
+    return
   end
 
-  local meta = {
-    providers = build_providers_meta(metas) or metas,
-    total_dim = #fused,
-    digest = providers_config_digest(providers_cfg),
-  }
+  local vectors = {}
+  local metas = {}
+  local remaining = 0
+
+  local function maybe_finish()
+    remaining = remaining - 1
+    if remaining == 0 then
+      -- Fuse; iterate by index up to #providers_cfg since failed providers
+      -- leave holes in vectors that ipairs would stop at
+      local fused = {}
+      for i = 1, #providers_cfg do
+        local v = vectors[i]
+        if v then
+          local w = (metas[i] and metas[i].weight) or 1.0
+          local norm_mode = (rule.fusion and rule.fusion.normalization) or 'none'
+          if norm_mode ~= 'none' then
+            v = apply_normalization(v, norm_mode)
+          end
+          for _, x in ipairs(v) do
+            fused[#fused + 1] = x * w
+          end
+        end
+      end
+      local meta = {
+        providers = build_providers_meta(metas) or metas,
+        total_dim = #fused,
+        digest = providers_config_digest(providers_cfg),
+      }
+      if #fused == 0 then
+        cb(nil, meta)
+      else
+        cb(fused, meta)
+      end
+    end
+  end
 
-  if #fused == 0 then
-    return nil, meta
+  local function start_provider(i, pcfg)
+    local prov = get_provider(pcfg.type or pcfg.name)
+    if not prov or not prov.collect_async then
+      maybe_finish()
+      return
+    end
+    prov.collect_async(task, {
+      profile = profile_or_set,
+      rule = rule,
+      config = pcfg,
+      weight = pcfg.weight or 1.0,
+      phase = phase,
+    }, function(vec, meta)
+      if vec then
+        metas[i] = meta or { name = pcfg.name or pcfg.type, type = pcfg.type, dim = #vec, weight = pcfg.weight or 1.0 }
+        vectors[i] = vec
+      end
+      maybe_finish()
+    end)
   end
 
-  return fused, meta
+  remaining = #providers_cfg
+  for i, pcfg in ipairs(providers_cfg) do
+    start_provider(i, pcfg)
+  end
 end
 
 -- This function receives training vectors, checks them, spawn learning and saves ANN in Redis
@@ -1102,7 +1153,7 @@ end
 
 return {
   can_push_train_vector = can_push_train_vector,
-  collect_features = collect_features,
+  collect_features_async = collect_features_async,
   create_ann = create_ann,
   default_options = default_options,
   build_providers_meta = build_providers_meta,
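The heart of the change is the function above: instead of collecting features inside a coroutine, collect_features_async fans the work out to callback-based providers and fuses the vectors when the last continuation fires. The pattern is a plain countdown: start N asynchronous jobs, decrement a shared counter in each callback, and invoke the final callback once the counter reaches zero. A minimal standalone sketch of that pattern, assuming nothing from rspamd (run_all and the toy providers are illustrative names only):

    -- Fan-out/fan-in with a countdown counter, as in collect_features_async.
    -- Each "provider" receives only a continuation; results keep their slot
    -- index so the fused output order stays deterministic.
    local function run_all(providers, cb)
      local results = {}
      local remaining = #providers

      local function maybe_finish()
        remaining = remaining - 1
        if remaining == 0 then
          cb(results)
        end
      end

      for i, start in ipairs(providers) do
        start(function(value)
          results[i] = value -- nil on failure leaves a hole: iterate 1..#providers when consuming
          maybe_finish()
        end)
      end
    end

    -- Toy usage: the second provider "fails" but still calls its continuation,
    -- which is what keeps the countdown correct.
    run_all({
      function(cont) cont({ 0.1, 0.2 }) end,
      function(cont) cont(nil) end,
    }, function(results)
      for i = 1, 2 do
        print(i, results[i] and #results[i] or 'no vector')
      end
    end)

The invariant that makes this safe is that every continuation fires exactly once; a provider that never calls back would stall the countdown forever, which is why start_provider calls maybe_finish() itself when a provider has no collect_async.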
diff --git a/lualib/plugins/neural/providers/llm.lua b/lualib/plugins/neural/providers/llm.lua
index fda0141e33..7ef14228b2 100644
--- a/lualib/plugins/neural/providers/llm.lua
+++ b/lualib/plugins/neural/providers/llm.lua
@@ -202,5 +202,96 @@ neural_common.register_provider('llm', {
     }
 
     return embedding, meta
+  end,
+  collect_async = function(task, ctx, cont)
+    local pcfg = ctx.config or {}
+    local llm = compose_llm_settings(pcfg)
+    if not llm.model then
+      return cont(nil)
+    end
+    local content = select_text(task, pcfg)
+    if not content or #content == 0 then
+      return cont(nil)
+    end
+    local body
+    if llm.type == 'openai' then
+      body = { model = llm.model, input = content }
+    elseif llm.type == 'ollama' then
+      body = { model = llm.model, prompt = content }
+    else
+      return cont(nil)
+    end
+    local cache_ctx = lua_cache.create_cache_context(neural_common.redis_params, {
+      cache_prefix = llm.cache_prefix,
+      cache_ttl = llm.cache_ttl,
+      cache_format = 'messagepack',
+      cache_hash_len = llm.cache_hash_len,
+      cache_use_hashing = llm.cache_use_hashing,
+    }, N)
+    local hasher = require 'rspamd_cryptobox_hash'
+    local key = string.format('%s:%s:%s', llm.type, llm.model or 'model', hasher.create(content):hex())
+
+    local function finish_with_embedding(embedding)
+      if not embedding then return cont(nil) end
+      for i = 1, #embedding do
+        embedding[i] = tonumber(embedding[i]) or 0.0
+      end
+      cont(embedding, {
+        name = pcfg.name or 'llm',
+        type = 'llm',
+        dim = #embedding,
+        weight = pcfg.weight or 1.0,
+        model = llm.model,
+        provider = llm.type,
+      })
+    end
+
+    local function request_and_cache()
+      local headers = { ['Content-Type'] = 'application/json' }
+      if llm.type == 'openai' and llm.api_key then
+        headers['Authorization'] = 'Bearer ' .. llm.api_key
+      end
+      local http_params = {
+        url = llm.url,
+        mime_type = 'application/json',
+        timeout = llm.timeout,
+        log_obj = task,
+        headers = headers,
+        body = ucl.to_format(body, 'json-compact', true),
+        task = task,
+        method = 'POST',
+        use_gzip = true,
+        callback = function(err, _, data)
+          if err then return cont(nil) end
+          local parser = ucl.parser()
+          local ok = parser:parse_text(data)
+          if not ok then return cont(nil) end
+          local parsed = parser:get_object()
+          local embedding = extract_embedding(llm.type, parsed)
+          if embedding and cache_ctx then
+            lua_cache.cache_set(task, key, { e = embedding }, cache_ctx)
+          end
+          finish_with_embedding(embedding)
+        end,
+      }
+      rspamd_http.request(http_params)
+    end
+
+    if cache_ctx then
+      lua_cache.cache_get(task, key, cache_ctx, llm.timeout or 2.0,
+        function(_)
+          request_and_cache()
+        end,
+        function(_, err, data)
+          if data and data.e then
+            finish_with_embedding(data.e)
+          else
+            request_and_cache()
+          end
+        end
+      )
+    else
+      request_and_cache()
+    end
   end
 })
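The llm provider's collect_async wraps the embedding request in a cache-aside flow: consult the Redis-backed cache first, issue the HTTP request on a miss (or on a cache error), then store the fresh embedding for subsequent tasks. Reduced to its skeleton, with callback-style get/set/fetch supplied by the caller (these names and signatures are illustrative, not the lua_cache API):

    -- Cache-aside with callbacks; get/set/fetch are caller-supplied,
    -- callback-style functions (hypothetical signatures).
    local function cached_fetch(key, get, set, fetch, cont)
      get(key, function(hit)
        if hit ~= nil then
          return cont(hit) -- cache hit: skip the expensive fetch entirely
        end
        fetch(key, function(value)
          if value ~= nil then
            set(key, value) -- populate the cache for the next caller
          end
          cont(value) -- continue either way
        end)
      end)
    end

    -- Toy in-memory usage
    local store = {}
    local function get(k, cb) cb(store[k]) end
    local function set(k, v) store[k] = v end
    local function fetch(_, cb) cb({ 0.5, 0.25 }) end

    cached_fetch('embed:key', get, set, fetch, function(v) print('dims', #v) end)
    cached_fetch('embed:key', get, set, fetch, function(v) print('cached dims', #v) end)

Note that in the diff both cache_get outcomes that lack a usable payload fall through to request_and_cache(), so a flaky cache degrades into a slower provider rather than a broken one.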
diff --git a/lualib/plugins/neural/providers/symbols.lua b/lualib/plugins/neural/providers/symbols.lua
index 6a3b750ca8..32941891bd 100644
--- a/lualib/plugins/neural/providers/symbols.lua
+++ b/lualib/plugins/neural/providers/symbols.lua
@@ -6,5 +6,9 @@ neural_common.register_provider('symbols', {
   collect = function(task, ctx)
     local vec = neural_common.result_to_vector(task, ctx.profile)
     return vec, { name = 'symbols', type = 'symbols', dim = #vec, weight = ctx.weight or 1.0 }
+  end,
+  collect_async = function(task, ctx, cont)
+    local vec = neural_common.result_to_vector(task, ctx.profile)
+    cont(vec, { name = 'symbols', type = 'symbols', dim = #vec, weight = ctx.weight or 1.0 })
   end
 })
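Both registrations of the symbols provider implement collect_async the same way: call the synchronous collector and invoke the continuation immediately. If more providers follow this shape, the adapter could be factored out — a hypothetical helper, not part of this commit:

    -- Derive collect_async from a synchronous collect; only valid for
    -- providers whose collect neither blocks nor yields.
    local function make_collect_async(collect)
      return function(task, ctx, cont)
        -- collect(task, ctx) as the sole argument expands to all of its
        -- return values, so cont receives both vec and meta
        cont(collect(task, ctx))
      end
    end

    -- e.g.: collect_async = make_collect_async(collect)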
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index 0a8ebcd692..633a45854a 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -111,81 +111,82 @@ local function ann_scores_filter(task)
     end
 
     if ann then
-      local vec
-      if rule.providers and #rule.providers > 0 then
-        local fused, meta = neural_common.collect_features(task, rule, profile)
-        vec = fused
-        if profile.providers_digest and meta.digest and profile.providers_digest ~= meta.digest then
+      local function after_features(vec, meta)
+        if profile.providers_digest and meta and meta.digest and profile.providers_digest ~= meta.digest then
           lua_util.debugm(N, task, 'providers digest mismatch for %s:%s, skip ANN apply',
             rule.prefix, set.name)
           vec = nil
         end
-      else
-        vec = neural_common.result_to_vector(task, profile)
-      end
-      local score
-      if not vec then
-        goto continue_rule
-      end
-      if set.ann.norm_stats then
-        vec = neural_common.apply_normalization(vec, set.ann.norm_stats)
-      end
-      local out = ann:apply1(vec, set.ann.pca)
-      score = out[1]
-
-      local symscore = string.format('%.3f', score)
-      task:cache_set(rule.prefix .. '_neural_score', score)
-      lua_util.debugm(N, task, '%s:%s:%s ann score: %s',
-        rule.prefix, set.name, set.ann.version, symscore)
-
-      if score > 0 then
-        local result = score
-
-        -- If spam_score_threshold is defined, override all other thresholds.
-        local spam_threshold = 0
-        if rule.spam_score_threshold then
-          spam_threshold = rule.spam_score_threshold
-        elseif rule.roc_enabled and not set.ann.roc_thresholds then
-          spam_threshold = set.ann.roc_thresholds[1]
+        local score
+        if not vec then
+          return
+        end
+        if set.ann.norm_stats then
+          vec = neural_common.apply_normalization(vec, set.ann.norm_stats)
         end
+        local out = ann:apply1(vec, set.ann.pca)
+        score = out[1]
+
+        local symscore = string.format('%.3f', score)
+        task:cache_set(rule.prefix .. '_neural_score', score)
+        lua_util.debugm(N, task, '%s:%s:%s ann score: %s',
+          rule.prefix, set.name, set.ann.version, symscore)
+
+        if score > 0 then
+          local result = score
+
+          -- If spam_score_threshold is defined, override all other thresholds.
+          local spam_threshold = 0
+          if rule.spam_score_threshold then
+            spam_threshold = rule.spam_score_threshold
+          elseif rule.roc_enabled and set.ann.roc_thresholds then
+            spam_threshold = set.ann.roc_thresholds[1]
+          end
 
-        if result >= spam_threshold then
-          if rule.flat_threshold_curve then
-            task:insert_result(rule.symbol_spam, 1.0, symscore)
+          if result >= spam_threshold then
+            if rule.flat_threshold_curve then
+              task:insert_result(rule.symbol_spam, 1.0, symscore)
+            else
+              task:insert_result(rule.symbol_spam, result, symscore)
+            end
           else
-            task:insert_result(rule.symbol_spam, result, symscore)
+            lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (spam threshold)',
+              rule.prefix, set.name, set.ann.version, symscore,
+              spam_threshold)
           end
         else
-          lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (spam threshold)',
-            rule.prefix, set.name, set.ann.version, symscore,
-            spam_threshold)
-        end
-      else
-        local result = -(score)
-
-        -- If ham_score_threshold is defined, override all other thresholds.
-        local ham_threshold = 0
-        if rule.ham_score_threshold then
-          ham_threshold = rule.ham_score_threshold
-        elseif rule.roc_enabled and not set.ann.roc_thresholds then
-          ham_threshold = set.ann.roc_thresholds[2]
-        end
+          local result = -(score)
+
+          -- If ham_score_threshold is defined, override all other thresholds.
+          local ham_threshold = 0
+          if rule.ham_score_threshold then
+            ham_threshold = rule.ham_score_threshold
+          elseif rule.roc_enabled and set.ann.roc_thresholds then
+            ham_threshold = set.ann.roc_thresholds[2]
+          end
 
-        if result >= ham_threshold then
-          if rule.flat_threshold_curve then
-            task:insert_result(rule.symbol_ham, 1.0, symscore)
+          if result >= ham_threshold then
+            if rule.flat_threshold_curve then
+              task:insert_result(rule.symbol_ham, 1.0, symscore)
+            else
+              task:insert_result(rule.symbol_ham, result, symscore)
+            end
           else
-            task:insert_result(rule.symbol_ham, result, symscore)
+            lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (ham threshold)',
+              rule.prefix, set.name, set.ann.version, result,
+              ham_threshold)
           end
-        else
-          lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (ham threshold)',
-            rule.prefix, set.name, set.ann.version, result,
-            ham_threshold)
         end
       end
+
+      if rule.providers and #rule.providers > 0 then
+        neural_common.collect_features_async(task, rule, profile, 'infer', after_features)
+      else
+        local vec = neural_common.result_to_vector(task, profile)
+        after_features(vec)
+      end
     end
-    ::continue_rule::
   end
 end
@@ -242,19 +243,19 @@ local function ann_push_task_result(rule, task, verdict, score, set)
       learn_ham = false
       learn_spam = false
 
-      -- Explicitly store tokens in cache
-      local vec
-      if rule.providers and #rule.providers > 0 then
-        local fused = neural_common.collect_features(task, rule, set, 'train')
-        if type(fused) == 'table' then
-          vec = fused
+      -- Explicitly store tokens in cache (use async collector if providers configured)
+      local function after_collect(vec)
+        if not vec then
+          vec = neural_common.result_to_vector(task, set)
         end
+        task:cache_set(rule.prefix .. '_neural_vec_mpack', ucl.to_format(vec, 'msgpack'))
+        task:cache_set(rule.prefix .. '_neural_profile_digest', set.digest)
       end
-      if not vec then
-        vec = neural_common.result_to_vector(task, set)
+      if rule.providers and #rule.providers > 0 then
+        neural_common.collect_features_async(task, rule, set, 'train', after_collect)
+      else
+        after_collect(nil)
       end
-      task:cache_set(rule.prefix .. '_neural_vec_mpack', ucl.to_format(vec, 'msgpack'))
-      task:cache_set(rule.prefix .. '_neural_profile_digest', set.digest)
       skip_reason = 'store_pool_only has been set'
     end
   end
@@ -274,12 +275,10 @@ local function ann_push_task_result(rule, task, verdict, score, set)
     if neural_common.can_push_train_vector(rule, task, learn_type, nspam, nham) then
       local vec
       if rule.providers and #rule.providers > 0 then
-        local fused = neural_common.collect_features(task, rule, set)
-        if type(fused) == 'table' then
-          vec = fused
-        end
-      end
-      if not vec then
+        -- Note: this training path remains sync for now; vectors are pushed when computed
+        -- fall back to legacy vector; async training push will be added later
+        vec = neural_common.result_to_vector(task, set)
+      else
         vec = neural_common.result_to_vector(task, set)
       end
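Taken together, the provider contract after this commit is collect_async(task, ctx, cont): ctx carries profile, rule, config, weight and phase, and cont(vec, meta) must be called exactly once, with cont(nil) signalling failure. A sketch of what a third-party provider could look like against that contract ('constant' and its fixed vector are placeholders; the meta fields mirror the built-in providers):

    neural_common.register_provider('constant', {
      -- ctx is populated by collect_features_async; phase is 'infer' or 'train'
      collect_async = function(task, ctx, cont)
        local vec = { 0.0, 1.0, 0.5 } -- placeholder; a real provider derives this from task
        cont(vec, {
          name = (ctx.config and ctx.config.name) or 'constant',
          type = 'constant',
          dim = #vec,
          weight = ctx.weight or 1.0,
        })
      end,
    })

A provider that only defines the synchronous collect is now skipped by collect_features_async (start_provider counts it down immediately), so collect alone no longer suffices to contribute features.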