]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Minor] Don't use coroutines
authorVsevolod Stakhov <vsevolod@rspamd.com>
Mon, 18 Aug 2025 12:58:39 +0000 (13:58 +0100)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Mon, 18 Aug 2025 12:58:39 +0000 (13:58 +0100)
lualib/plugins/neural.lua
lualib/plugins/neural/providers/llm.lua
lualib/plugins/neural/providers/symbols.lua
src/plugins/lua/neural.lua

index b13c6a8273c71dff35521ad5d340594ab7c6db57..22e77cb4bea018a98b695d0d1440685312ee1e8f 100644 (file)
@@ -130,10 +130,13 @@ local result_to_vector
 -- Built-in symbols provider (compatibility path)
 register_provider('symbols', {
   collect = function(task, ctx)
-    -- ctx.profile is expected for symbols provider
     local vec = result_to_vector(task, ctx.profile)
     return vec, { name = 'symbols', type = 'symbols', dim = #vec, weight = ctx.weight or 1.0 }
-  end
+  end,
+  collect_async = function(task, ctx, cont)
+    local vec = result_to_vector(task, ctx.profile)
+    cont(vec, { name = 'symbols', type = 'symbols', dim = #vec, weight = ctx.weight or 1.0 })
+  end,
 })
 
 local function load_scripts()
@@ -566,76 +569,123 @@ end
 
 -- If no providers configured, fallback to symbols provider unless disabled
 -- phase: 'infer' | 'train'
-local function collect_features(task, rule, profile_or_set, phase)
-  local vectors = {}
-  local metas = {}
+-- Removed synchronous collect_features; use collect_features_async instead
 
+-- Async version: runs providers in parallel and calls cb(fused, meta) when done
+local function collect_features_async(task, rule, profile_or_set, phase, cb)
   local providers_cfg = rule.providers
   if not providers_cfg or #providers_cfg == 0 then
-    if not rule.disable_symbols_input then
-      local prov = get_provider('symbols')
-      if prov then
-        local vec, meta = prov.collect(task, { profile = profile_or_set, weight = 1.0 })
+    if rule.disable_symbols_input then
+      cb(nil, { providers = {}, total_dim = 0, digest = providers_config_digest(providers_cfg) })
+      return
+    end
+    local prov = get_provider('symbols')
+    if prov and prov.collect_async then
+      prov.collect_async(task, { profile = profile_or_set, weight = 1.0, phase = phase }, function(vec, meta)
+        local metas = {}
         if vec then
-          vectors[#vectors + 1] = vec
-          metas[#metas + 1] = meta
+          metas[1] = meta
         end
-      end
-    end
-  else
-    for _, pcfg in ipairs(providers_cfg) do
-      local prov = get_provider(pcfg.type or pcfg.name)
-      if prov then
-        local ok, vec, meta = pcall(function()
-          return prov.collect(task, {
-            profile = profile_or_set,
-            rule = rule,
-            config = pcfg,
-            weight = pcfg.weight or 1.0,
-            phase = phase,
-          })
-        end)
-        if ok and vec then
-          if meta then
-            meta.weight = pcfg.weight or meta.weight or 1.0
+        local fused = {}
+        if vec then
+          local w = (meta and meta.weight) or 1.0
+          local norm_mode = (rule.fusion and rule.fusion.normalization) or 'none'
+          if norm_mode ~= 'none' then
+            vec = apply_normalization(vec, norm_mode)
+          end
+          for _, x in ipairs(vec) do
+            fused[#fused + 1] = x * w
           end
-          vectors[#vectors + 1] = vec
-          metas[#metas + 1] = meta or
-              { name = pcfg.name or pcfg.type, type = pcfg.type, dim = #vec, weight = pcfg.weight or 1.0 }
-        else
-          rspamd_logger.debugm(N, rspamd_config, 'provider %s failed to collect features', pcfg.type or pcfg.name)
         end
-      else
-        rspamd_logger.debugm(N, rspamd_config, 'provider %s is not registered', pcfg.type or pcfg.name)
-      end
+        cb(#fused > 0 and fused or nil, {
+          providers = build_providers_meta(metas) or metas,
+          total_dim = #fused,
+          digest = providers_config_digest(providers_cfg),
+        })
+      end)
+      return
     end
-  end
-
-  -- Simple fusion by concatenation; optional per-provider weight scaling
-  local fused = {}
-  for i, v in ipairs(vectors) do
-    local w = (metas[i] and metas[i].weight) or 1.0
-    -- Apply normalization if requested
+    -- Fallback: direct symbols compute
+    local vec = result_to_vector(task, profile_or_set)
+    local meta = { name = 'symbols', type = 'symbols', dim = #vec, weight = 1.0 }
+    local fused = {}
+    local w = 1.0
     local norm_mode = (rule.fusion and rule.fusion.normalization) or 'none'
     if norm_mode ~= 'none' then
-      v = apply_normalization(v, norm_mode)
+      vec = apply_normalization(vec, norm_mode)
     end
-    for _, x in ipairs(v) do
+    for _, x in ipairs(vec) do
       fused[#fused + 1] = x * w
     end
+    cb(fused,
+      {
+        providers = build_providers_meta({ meta }) or { meta },
+        total_dim = #fused,
+        digest = providers_config_digest(
+          providers_cfg)
+      })
+    return
   end
 
-  local meta = {
-    providers = build_providers_meta(metas) or metas,
-    total_dim = #fused,
-    digest = providers_config_digest(providers_cfg),
-  }
+  local vectors = {}
+  local metas = {}
+  local remaining = 0
+
+  local function maybe_finish()
+    remaining = remaining - 1
+    if remaining == 0 then
+      -- Fuse
+      local fused = {}
+      for i, v in ipairs(vectors) do
+        if v then
+          local w = (metas[i] and metas[i].weight) or 1.0
+          local norm_mode = (rule.fusion and rule.fusion.normalization) or 'none'
+          if norm_mode ~= 'none' then
+            v = apply_normalization(v, norm_mode)
+          end
+          for _, x in ipairs(v) do
+            fused[#fused + 1] = x * w
+          end
+        end
+      end
+      local meta = {
+        providers = build_providers_meta(metas) or metas,
+        total_dim = #fused,
+        digest = providers_config_digest(providers_cfg),
+      }
+      if #fused == 0 then
+        cb(nil, meta)
+      else
+        cb(fused, meta)
+      end
+    end
+  end
 
-  if #fused == 0 then
-    return nil, meta
+  local function start_provider(i, pcfg)
+    local prov = get_provider(pcfg.type or pcfg.name)
+    if not prov or not prov.collect_async then
+      maybe_finish()
+      return
+    end
+    prov.collect_async(task, {
+      profile = profile_or_set,
+      rule = rule,
+      config = pcfg,
+      weight = pcfg.weight or 1.0,
+      phase = phase,
+    }, function(vec, meta)
+      if vec then
+        metas[i] = meta or { name = pcfg.name or pcfg.type, type = pcfg.type, dim = #vec, weight = pcfg.weight or 1.0 }
+        vectors[i] = vec
+      end
+      maybe_finish()
+    end)
   end
 
-  return fused, meta
+  remaining = #providers_cfg
+  for i, pcfg in ipairs(providers_cfg) do
+    start_provider(i, pcfg)
+  end
 end
 
 -- This function receives training vectors, checks them, spawn learning and saves ANN in Redis
@@ -1102,7 +1152,7 @@ end
 
 return {
   can_push_train_vector = can_push_train_vector,
-  collect_features = collect_features,
+  collect_features_async = collect_features_async,
   create_ann = create_ann,
   default_options = default_options,
   build_providers_meta = build_providers_meta,
index fda0141e33c9ba5b8991a39a96d42de9322c20ee..7ef14228b2132dfd35c8965a5767d442fba88899 100644 (file)
@@ -202,5 +202,96 @@ neural_common.register_provider('llm', {
     }
 
     return embedding, meta
+  end,
+  collect_async = function(task, ctx, cont)
+    local pcfg = ctx.config or {}
+    local llm = compose_llm_settings(pcfg)
+    if not llm.model then
+      return cont(nil)
+    end
+    local content = select_text(task, pcfg)
+    if not content or #content == 0 then
+      return cont(nil)
+    end
+    local body
+    if llm.type == 'openai' then
+      body = { model = llm.model, input = content }
+    elseif llm.type == 'ollama' then
+      body = { model = llm.model, prompt = content }
+    else
+      return cont(nil)
+    end
+    local cache_ctx = lua_cache.create_cache_context(neural_common.redis_params, {
+      cache_prefix = llm.cache_prefix,
+      cache_ttl = llm.cache_ttl,
+      cache_format = 'messagepack',
+      cache_hash_len = llm.cache_hash_len,
+      cache_use_hashing = llm.cache_use_hashing,
+    }, N)
+    local hasher = require 'rspamd_cryptobox_hash'
+    local key = string.format('%s:%s:%s', llm.type, llm.model or 'model', hasher.create(content):hex())
+
+    local function finish_with_embedding(embedding)
+      if not embedding then return cont(nil) end
+      for i = 1, #embedding do
+        embedding[i] = tonumber(embedding[i]) or 0.0
+      end
+      cont(embedding, {
+        name = pcfg.name or 'llm',
+        type = 'llm',
+        dim = #embedding,
+        weight = pcfg.weight or 1.0,
+        model = llm.model,
+        provider = llm.type,
+      })
+    end
+
+    local function request_and_cache()
+      local headers = { ['Content-Type'] = 'application/json' }
+      if llm.type == 'openai' and llm.api_key then
+        headers['Authorization'] = 'Bearer ' .. llm.api_key
+      end
+      local http_params = {
+        url = llm.url,
+        mime_type = 'application/json',
+        timeout = llm.timeout,
+        log_obj = task,
+        headers = headers,
+        body = ucl.to_format(body, 'json-compact', true),
+        task = task,
+        method = 'POST',
+        use_gzip = true,
+        callback = function(err, _, data)
+          if err then return cont(nil) end
+          local parser = ucl.parser()
+          local ok = parser:parse_text(data)
+          if not ok then return cont(nil) end
+          local parsed = parser:get_object()
+          local embedding = extract_embedding(llm.type, parsed)
+          if embedding and cache_ctx then
+            lua_cache.cache_set(task, key, { e = embedding }, cache_ctx)
+          end
+          finish_with_embedding(embedding)
+        end,
+      }
+      rspamd_http.request(http_params)
+    end
+
+    if cache_ctx then
+      lua_cache.cache_get(task, key, cache_ctx, llm.timeout or 2.0,
+        function(_)
+          request_and_cache()
+        end,
+        function(_, err, data)
+          if data and data.e then
+            finish_with_embedding(data.e)
+          else
+            request_and_cache()
+          end
+        end
+      )
+    else
+      request_and_cache()
+    end
   end
 })
index 6a3b750ca86890de6b86965a2a992fb36124ef80..32941891bd43dcfa3eca655f1f03e9fcf8a3a124 100644 (file)
@@ -6,5 +6,9 @@ neural_common.register_provider('symbols', {
   collect = function(task, ctx)
     local vec = neural_common.result_to_vector(task, ctx.profile)
     return vec, { name = 'symbols', type = 'symbols', dim = #vec, weight = ctx.weight or 1.0 }
+  end,
+  collect_async = function(task, ctx, cont)
+    local vec = neural_common.result_to_vector(task, ctx.profile)
+    cont(vec, { name = 'symbols', type = 'symbols', dim = #vec, weight = ctx.weight or 1.0 })
   end
 })
index 0a8ebcd6926d4d219dcdb5cd061055a333c7eb68..633a45854a539fd9d88ef359f9e40cedc66bfba1 100644 (file)
@@ -111,81 +111,82 @@ local function ann_scores_filter(task)
     end
 
     if ann then
-      local vec
-      if rule.providers and #rule.providers > 0 then
-        local fused, meta = neural_common.collect_features(task, rule, profile)
-        vec = fused
-        if profile.providers_digest and meta.digest and profile.providers_digest ~= meta.digest then
+      local function after_features(vec, meta)
+        if profile.providers_digest and meta and meta.digest and profile.providers_digest ~= meta.digest then
           lua_util.debugm(N, task, 'providers digest mismatch for %s:%s, skip ANN apply',
             rule.prefix, set.name)
           vec = nil
         end
-      else
-        vec = neural_common.result_to_vector(task, profile)
-      end
 
-      local score
-      if not vec then
-        goto continue_rule
-      end
-      if set.ann.norm_stats then
-        vec = neural_common.apply_normalization(vec, set.ann.norm_stats)
-      end
-      local out = ann:apply1(vec, set.ann.pca)
-      score = out[1]
-
-      local symscore = string.format('%.3f', score)
-      task:cache_set(rule.prefix .. '_neural_score', score)
-      lua_util.debugm(N, task, '%s:%s:%s ann score: %s',
-        rule.prefix, set.name, set.ann.version, symscore)
-
-      if score > 0 then
-        local result = score
-
-        -- If spam_score_threshold is defined, override all other thresholds.
-        local spam_threshold = 0
-        if rule.spam_score_threshold then
-          spam_threshold = rule.spam_score_threshold
-        elseif rule.roc_enabled and not set.ann.roc_thresholds then
-          spam_threshold = set.ann.roc_thresholds[1]
+        local score
+        if not vec then
+          return
+        end
+        if set.ann.norm_stats then
+          vec = neural_common.apply_normalization(vec, set.ann.norm_stats)
         end
+        local out = ann:apply1(vec, set.ann.pca)
+        score = out[1]
+
+        local symscore = string.format('%.3f', score)
+        task:cache_set(rule.prefix .. '_neural_score', score)
+        lua_util.debugm(N, task, '%s:%s:%s ann score: %s',
+          rule.prefix, set.name, set.ann.version, symscore)
+
+        if score > 0 then
+          local result = score
+
+          -- If spam_score_threshold is defined, override all other thresholds.
+          local spam_threshold = 0
+          if rule.spam_score_threshold then
+            spam_threshold = rule.spam_score_threshold
+          elseif rule.roc_enabled and not set.ann.roc_thresholds then
+            spam_threshold = set.ann.roc_thresholds[1]
+          end
 
-        if result >= spam_threshold then
-          if rule.flat_threshold_curve then
-            task:insert_result(rule.symbol_spam, 1.0, symscore)
+          if result >= spam_threshold then
+            if rule.flat_threshold_curve then
+              task:insert_result(rule.symbol_spam, 1.0, symscore)
+            else
+              task:insert_result(rule.symbol_spam, result, symscore)
+            end
           else
-            task:insert_result(rule.symbol_spam, result, symscore)
+            lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (spam threshold)',
+              rule.prefix, set.name, set.ann.version, symscore,
+              spam_threshold)
           end
         else
-          lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (spam threshold)',
-            rule.prefix, set.name, set.ann.version, symscore,
-            spam_threshold)
-        end
-      else
-        local result = -(score)
-
-        -- If ham_score_threshold is defined, override all other thresholds.
-        local ham_threshold = 0
-        if rule.ham_score_threshold then
-          ham_threshold = rule.ham_score_threshold
-        elseif rule.roc_enabled and not set.ann.roc_thresholds then
-          ham_threshold = set.ann.roc_thresholds[2]
-        end
+          local result = -(score)
+
+          -- If ham_score_threshold is defined, override all other thresholds.
+          local ham_threshold = 0
+          if rule.ham_score_threshold then
+            ham_threshold = rule.ham_score_threshold
+          elseif rule.roc_enabled and not set.ann.roc_thresholds then
+            ham_threshold = set.ann.roc_thresholds[2]
+          end
 
-        if result >= ham_threshold then
-          if rule.flat_threshold_curve then
-            task:insert_result(rule.symbol_ham, 1.0, symscore)
+          if result >= ham_threshold then
+            if rule.flat_threshold_curve then
+              task:insert_result(rule.symbol_ham, 1.0, symscore)
+            else
+              task:insert_result(rule.symbol_ham, result, symscore)
+            end
           else
-            task:insert_result(rule.symbol_ham, result, symscore)
+            lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (ham threshold)',
+              rule.prefix, set.name, set.ann.version, result,
+              ham_threshold)
           end
-        else
-          lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (ham threshold)',
-            rule.prefix, set.name, set.ann.version, result,
-            ham_threshold)
         end
       end
+
+      if rule.providers and #rule.providers > 0 then
+        neural_common.collect_features_async(task, rule, profile, 'infer', after_features)
+      else
+        local vec = neural_common.result_to_vector(task, profile)
+        after_features(vec)
+      end
     end
-    ::continue_rule::
   end
 end
 
@@ -242,19 +243,19 @@ local function ann_push_task_result(rule, task, verdict, score, set)
       learn_ham = false
       learn_spam = false
 
-      -- Explicitly store tokens in cache
-      local vec
-      if rule.providers and #rule.providers > 0 then
-        local fused = neural_common.collect_features(task, rule, set, 'train')
-        if type(fused) == 'table' then
-          vec = fused
+      -- Explicitly store tokens in cache (use async collector if providers configured)
+      local function after_collect(vec)
+        if not vec then
+          vec = neural_common.result_to_vector(task, set)
         end
+        task:cache_set(rule.prefix .. '_neural_vec_mpack', ucl.to_format(vec, 'msgpack'))
+        task:cache_set(rule.prefix .. '_neural_profile_digest', set.digest)
       end
-      if not vec then
-        vec = neural_common.result_to_vector(task, set)
+      if rule.providers and #rule.providers > 0 then
+        neural_common.collect_features_async(task, rule, set, 'train', after_collect)
+      else
+        after_collect(nil)
       end
-      task:cache_set(rule.prefix .. '_neural_vec_mpack', ucl.to_format(vec, 'msgpack'))
-      task:cache_set(rule.prefix .. '_neural_profile_digest', set.digest)
       skip_reason = 'store_pool_only has been set'
     end
   end
@@ -274,12 +275,10 @@ local function ann_push_task_result(rule, task, verdict, score, set)
         if neural_common.can_push_train_vector(rule, task, learn_type, nspam, nham) then
           local vec
           if rule.providers and #rule.providers > 0 then
-            local fused = neural_common.collect_features(task, rule, set)
-            if type(fused) == 'table' then
-              vec = fused
-            end
-          end
-          if not vec then
+            -- Note: this training path remains sync for now; vectors are pushed when computed
+            -- fall back to legacy vector; async training push will be added later
+            vec = neural_common.result_to_vector(task, set)
+          else
             vec = neural_common.result_to_vector(task, set)
           end