]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Project] Rework rspamc to allow training of different neural types
authorVsevolod Stakhov <vsevolod@rspamd.com>
Thu, 28 Aug 2025 11:32:16 +0000 (12:32 +0100)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Thu, 28 Aug 2025 11:32:16 +0000 (12:32 +0100)
lualib/plugins/neural/providers/llm.lua
rules/controller/neural.lua
src/client/rspamc.cxx
src/client/rspamdclient.c
src/plugins/lua/neural.lua
test/functional/lib/rspamd.robot

index 33301e9084cdf5fee6a18d90a01ed0bf98a5c0c5..4f17979c50ac453ff190d0c73f5d1ae92163b8af 100644 (file)
@@ -14,8 +14,8 @@ local llm_common = require "llm_common"
 local N = "neural.llm"
 
 local function select_text(task)
-  local content = llm_common.build_llm_input(task)
-  return content
+  local input_tbl = llm_common.build_llm_input(task)
+  return input_tbl
 end
 
 local function compose_llm_settings(pcfg)
@@ -72,23 +72,29 @@ neural_common.register_provider('llm', {
       return nil
     end
 
-    local content = select_text(task)
-    if not content or #content == 0 then
+    local input_tbl = select_text(task)
+    if not input_tbl then
       rspamd_logger.debugm(N, task, 'llm provider has no content to embed; skip')
       return nil
     end
 
+    -- Build request input string (text then optional subject), keeping rspamd_text intact
+    local input_string = input_tbl.text or ''
+    if input_tbl.subject and input_tbl.subject ~= '' then
+      input_string = input_string .. "\nSubject: " .. input_tbl.subject
+    end
+
     local body
     if llm.type == 'openai' then
-      body = { model = llm.model, input = content }
+      body = { model = llm.model, input = input_string }
     elseif llm.type == 'ollama' then
-      body = { model = llm.model, prompt = content }
+      body = { model = llm.model, prompt = input_string }
     else
       rspamd_logger.debugm(N, task, 'unsupported llm type: %s', llm.type)
       return nil
     end
 
-    -- Redis cache: use content hash + model + provider as key
+    -- Redis cache: hash the final input string only (IUF is trivial here)
     local cache_ctx = lua_cache.create_cache_context(neural_common.redis_params, {
       cache_prefix = llm.cache_prefix,
       cache_ttl = llm.cache_ttl,
@@ -97,9 +103,8 @@ neural_common.register_provider('llm', {
       cache_use_hashing = llm.cache_use_hashing,
     }, N)
 
-    -- Use a stable key based on content digest
     local hasher = require 'rspamd_cryptobox_hash'
-    local key = string.format('%s:%s:%s', llm.type, llm.model or 'model', hasher.create(content):hex())
+    local key = string.format('%s:%s:%s', llm.type, llm.model or 'model', hasher.create(input_string):hex())
 
     local function do_request_and_cache()
       local headers = { ['Content-Type'] = 'application/json' }
@@ -119,161 +124,39 @@ neural_common.register_provider('llm', {
         use_gzip = true,
       }
 
-      local err, data = rspamd_http.request(http_params)
-      if err then
-        rspamd_logger.debugm(N, task, 'llm request failed: %s', err)
-        return nil
-      end
-
-      local parser = ucl.parser()
-      local ok, perr = parser:parse_string(data.content)
-      if not ok then
-        rspamd_logger.debugm(N, task, 'cannot parse llm response: %s', perr)
-        return nil
-      end
-
-      local parsed = parser:get_object()
-      local embedding = extract_embedding(llm.type, parsed)
-      if not embedding or #embedding == 0 then
-        rspamd_logger.debugm(N, task, 'no embedding in llm response')
-        return nil
-      end
-
-      for i = 1, #embedding do
-        embedding[i] = tonumber(embedding[i]) or 0.0
-      end
-
-      lua_cache.cache_set(task, key, { e = embedding }, cache_ctx)
-      return embedding
-    end
+      local function http_cb(err, code, resp, _)
+        if err then
+          rspamd_logger.debugm(N, task, 'llm http error: %s', err)
+          return
+        end
+        if code ~= 200 or not resp then
+          rspamd_logger.debugm(N, task, 'llm bad http code: %s', code)
+          return
+        end
 
-    -- Try cache first
-    local cached_result
-    local done = false
-    lua_cache.cache_get(task, key, cache_ctx, llm.timeout or 2.0,
-      function(_)
-        -- Uncached: perform request synchronously and store
-        cached_result = do_request_and_cache()
-        done = true
-      end,
-      function(_, err, data)
-        if data and data.e then
-          cached_result = data.e
+        local parser = ucl.parser()
+        local ok, perr = parser:parse_string(resp)
+        if not ok then
+          rspamd_logger.debugm(N, task, 'llm cannot parse reply: %s', perr)
+          return
+        end
+        local parsed = parser:get_object()
+        local emb = extract_embedding(llm.type, parsed)
+        if type(emb) == 'table' then
+          cache_ctx:set_cached(key, emb)
+          neural_common.append_provider_vector(ctx, { provider = 'llm', vector = emb })
         end
-        done = true
       end
-    )
 
-    if not done then
-      -- Fallback: ensure we still do the request now (cache API is async-ready, but we need sync path)
-      cached_result = do_request_and_cache()
+      rspamd_http.request(http_params, http_cb)
     end
 
-    local embedding = cached_result
-    if not embedding then
-      return nil
+    local cached = cache_ctx:get_cached(key)
+    if type(cached) == 'table' then
+      neural_common.append_provider_vector(ctx, { provider = 'llm', vector = cached })
+      return
     end
 
-    local meta = {
-      name = pcfg.name or 'llm',
-      type = 'llm',
-      dim = #embedding,
-      weight = pcfg.weight or 1.0,
-      model = llm.model,
-      provider = llm.type,
-    }
-
-    return embedding, meta
+    do_request_and_cache()
   end,
-  collect_async = function(task, ctx, cont)
-    local pcfg = ctx.config or {}
-    local llm = compose_llm_settings(pcfg)
-    if not llm.model then
-      return cont(nil)
-    end
-    local content = select_text(task)
-    if not content or #content == 0 then
-      return cont(nil)
-    end
-    local body
-    if llm.type == 'openai' then
-      body = { model = llm.model, input = content }
-    elseif llm.type == 'ollama' then
-      body = { model = llm.model, prompt = content }
-    else
-      return cont(nil)
-    end
-    local cache_ctx = lua_cache.create_cache_context(neural_common.redis_params, {
-      cache_prefix = llm.cache_prefix,
-      cache_ttl = llm.cache_ttl,
-      cache_format = 'messagepack',
-      cache_hash_len = llm.cache_hash_len,
-      cache_use_hashing = llm.cache_use_hashing,
-    }, N)
-    local hasher = require 'rspamd_cryptobox_hash'
-    local key = string.format('%s:%s:%s', llm.type, llm.model or 'model', hasher.create(content):hex())
-
-    local function finish_with_embedding(embedding)
-      if not embedding then return cont(nil) end
-      for i = 1, #embedding do
-        embedding[i] = tonumber(embedding[i]) or 0.0
-      end
-      cont(embedding, {
-        name = pcfg.name or 'llm',
-        type = 'llm',
-        dim = #embedding,
-        weight = pcfg.weight or 1.0,
-        model = llm.model,
-        provider = llm.type,
-      })
-    end
-
-    local function request_and_cache()
-      local headers = { ['Content-Type'] = 'application/json' }
-      if llm.type == 'openai' and llm.api_key then
-        headers['Authorization'] = 'Bearer ' .. llm.api_key
-      end
-      local http_params = {
-        url = llm.url,
-        mime_type = 'application/json',
-        timeout = llm.timeout,
-        log_obj = task,
-        headers = headers,
-        body = ucl.to_format(body, 'json-compact', true),
-        task = task,
-        method = 'POST',
-        use_gzip = true,
-        callback = function(err, _, data)
-          if err then return cont(nil) end
-          local parser = ucl.parser()
-          local ok = parser:parse_text(data)
-          if not ok then return cont(nil) end
-          local parsed = parser:get_object()
-          local embedding = extract_embedding(llm.type, parsed)
-          if embedding and cache_ctx then
-            lua_cache.cache_set(task, key, { e = embedding }, cache_ctx)
-          end
-          finish_with_embedding(embedding)
-        end,
-      }
-      rspamd_http.request(http_params)
-    end
-
-    if cache_ctx then
-      lua_cache.cache_get(task, key, cache_ctx, llm.timeout or 2.0,
-        function(_)
-          request_and_cache()
-        end,
-        function(_, err, data)
-          if data and data.e then
-            finish_with_embedding(data.e)
-          else
-            request_and_cache()
-          end
-        end
-      )
-    else
-      request_and_cache()
-    end
-  end
 })
index 0a55eebeebdabdb5db4c09f49e0ca1a95f417d1f..6e8cd80b587d21326e96755c0a56bec483b88c60 100644 (file)
@@ -18,8 +18,12 @@ local neural_common = require "plugins/neural"
 local ts = require("tableshape").types
 local ucl = require "ucl"
 local lua_util = require "lua_util"
+local rspamd_util = require "rspamd_util"
+local lua_redis = require "lua_redis"
+local rspamd_logger = require "rspamd_logger"
 
 local E = {}
+local N = 'neural'
 
 -- Controller neural plugin
 
@@ -30,6 +34,7 @@ local learn_request_schema = ts.shape {
 }
 
 local function handle_learn(task, conn)
+  lua_util.debugm(N, task, 'controller.neural: learn called')
   local parser = ucl.parser()
   local ok, err = parser:parse_text(task:get_rawbody())
   if not ok then
@@ -59,10 +64,12 @@ local function handle_learn(task, conn)
     worker = task:get_worker(),
   }
 
+  lua_util.debugm(N, task, 'controller.neural: learn scheduled for rule=%s', rule_name)
   conn:send_string('{"success" : true}')
 end
 
 local function handle_status(task, conn, req_params)
+  lua_util.debugm(N, task, 'controller.neural: status called')
   local out = {
     rules = {},
   }
@@ -72,7 +79,20 @@ local function handle_status(task, conn, req_params)
       fusion = rule.fusion,
       max_inputs = rule.max_inputs,
       settings = {},
+      requires_scan = false,
     }
+    -- Default: if no providers configured, assume symbols (full scan required)
+    local has_providers = type(rule.providers) == 'table' and #rule.providers > 0
+    if not has_providers then
+      r.requires_scan = true
+    else
+      for _, p in ipairs(rule.providers) do
+        if p.type == 'symbols' then
+          r.requires_scan = true
+          break
+        end
+      end
+    end
     for sid, set in pairs(rule.settings or {}) do
       if type(set) == 'table' then
         local s = {
@@ -95,6 +115,144 @@ local function handle_status(task, conn, req_params)
   conn:send_ucl({ success = true, data = out })
 end
 
+-- Return compact configuration for clients (e.g. rspamc) to plan learning
+local function handle_config(task, conn, req_params)
+  lua_util.debugm(N, task, 'controller.neural: config called')
+  local out = {
+    rules = {},
+  }
+
+  for name, rule in pairs(neural_common.settings.rules) do
+    local requires_scan = false
+    local has_providers = type(rule.providers) == 'table' and #rule.providers > 0
+    if not has_providers then
+      requires_scan = true
+    else
+      for _, p in ipairs(rule.providers) do
+        if p.type == 'symbols' then
+          requires_scan = true
+          break
+        end
+      end
+    end
+
+    local r = {
+      requires_scan = requires_scan,
+      providers = {},
+      recommended_path = requires_scan and '/checkv2' or '/controller/neural/learn_message',
+      settings = {},
+    }
+
+    if has_providers then
+      for _, p in ipairs(rule.providers) do
+        r.providers[#r.providers + 1] = { type = p.type }
+      end
+    end
+
+    for sid, set in pairs(rule.settings or {}) do
+      if type(set) == 'table' then
+        r.settings[#r.settings + 1] = set.name
+      end
+    end
+
+    out.rules[name] = r
+  end
+
+  conn:send_ucl({ success = true, data = out })
+end
+
+-- Train directly from a message for providers that don't require full /checkv2
+-- Headers:
+--  - ANN-Train or Class: 'spam' | 'ham'
+--  - Rule: rule name (optional, default 'default')
+local function handle_learn_message(task, conn)
+  lua_util.debugm(N, task, 'controller.neural: learn_message called')
+  local cls = task:get_request_header('ANN-Train') or task:get_request_header('Class')
+  if not cls then
+    conn:send_error(400, 'missing class header (ANN-Train or Class)')
+    return
+  end
+
+  local learn_type = tostring(cls):lower()
+  if learn_type ~= 'spam' and learn_type ~= 'ham' then
+    conn:send_error(400, 'unsupported class (expected spam or ham)')
+    return
+  end
+
+  local rule_name = task:get_request_header('Rule') or 'default'
+  local rule = neural_common.settings.rules[rule_name]
+  if not rule then
+    conn:send_error(400, 'unknown rule')
+    return
+  end
+
+  -- If no providers or symbols provider configured, require full scan path
+  local has_providers = type(rule.providers) == 'table' and #rule.providers > 0
+  if not has_providers then
+    lua_util.debugm(N, task, 'controller.neural: learn_message refused: no providers (assume symbols) for rule=%s',
+      rule_name)
+    conn:send_error(400, 'rule requires full /checkv2 scan (no providers configured)')
+    return
+  end
+  for _, p in ipairs(rule.providers) do
+    if p.type == 'symbols' then
+      lua_util.debugm(N, task, 'controller.neural: learn_message refused due to symbols provider for rule=%s', rule_name)
+      conn:send_error(400, 'rule requires full /checkv2 scan (symbols provider present)')
+      return
+    end
+  end
+
+  local set = neural_common.get_rule_settings(task, rule)
+  if not set or not set.ann or not set.ann.redis_key then
+    conn:send_error(400, 'invalid rule settings for learning')
+    return
+  end
+
+  local function after_collect(vec)
+    lua_util.debugm(N, task, 'controller.neural: learn_message after_collect, vector=%s', type(vec))
+    if not vec then
+      vec = neural_common.result_to_vector(task, set)
+    end
+
+    if type(vec) ~= 'table' then
+      conn:send_error(500, 'failed to build training vector')
+      return
+    end
+
+    local compressed = rspamd_util.zstd_compress(table.concat(vec, ';'))
+    local target_key = string.format('%s_%s_set', set.ann.redis_key, learn_type)
+
+    local function learn_vec_cb(redis_err)
+      if redis_err then
+        rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s',
+          rule.prefix, set.name, redis_err)
+        conn:send_error(500, 'cannot store train vector')
+      else
+        lua_util.debugm(N, task, 'controller.neural: stored train vector for rule=%s key=%s bytes=%s', rule_name,
+          target_key, #compressed)
+        conn:send_ucl({ success = true, stored = #compressed, key = target_key })
+      end
+    end
+
+    lua_redis.redis_make_request(task,
+      rule.redis,
+      nil,
+      true,
+      learn_vec_cb,
+      'SADD',
+      { target_key, compressed }
+    )
+  end
+
+  if rule.providers and #rule.providers > 0 then
+    lua_util.debugm(N, task, 'controller.neural: collecting features for rule=%s', rule_name)
+    neural_common.collect_features_async(task, rule, set, 'train', after_collect)
+  else
+    -- Should not reach here due to early return
+    conn:send_error(400, 'rule requires full /checkv2 scan (no providers configured)')
+  end
+end
+
 local function handle_train(task, conn, req_params)
   local rule_name = req_params.rule or 'default'
   local rule = neural_common.settings.rules[rule_name]
@@ -115,6 +273,16 @@ return {
     enable = true,
     need_task = true,
   },
+  config = {
+    handler = handle_config,
+    enable = true,
+    need_task = false,
+  },
+  learn_message = {
+    handler = handle_learn_message,
+    enable = true,
+    need_task = true,
+  },
   status = {
     handler = handle_status,
     enable = false,
index 1dc48faaec551cdd8a4f78269ae941c6941c1094..e2128f357aecb416f5bd4cf2ce73f21cdfca112f 100644 (file)
@@ -32,6 +32,7 @@
 #include <cstdio>
 #include <cmath>
 #include <locale>
+#include <unordered_map>
 
 #include "frozen/string.h"
 #include "frozen/unordered_map.h"
@@ -59,7 +60,7 @@ static const char *user = nullptr;
 static const char *helo = nullptr;
 static const char *hostname = nullptr;
 static const char *classifier = nullptr;
-static const char *learn_class_name = nullptr;
+static std::string learn_class_name;
 static const char *local_addr = nullptr;
 static const char *execute = nullptr;
 static const char *sort = nullptr;
@@ -94,6 +95,10 @@ static const char *files_list = nullptr;
 static const char *queue_id = nullptr;
 static const char *log_tag = nullptr;
 static std::string settings;
+static std::string neural_train;
+static std::string neural_rule;
+static bool neural_cfg_loaded = false;
+static std::unordered_map<std::string, bool> neural_rule_requires_scan;
 
 std::vector<GPid> children;
 static GPatternSpec **exclude_compiled = nullptr;
@@ -214,6 +219,7 @@ enum rspamc_command_type {
        RSPAMC_COMMAND_LEARN_SPAM,
        RSPAMC_COMMAND_LEARN_HAM,
        RSPAMC_COMMAND_LEARN_CLASS,
+       RSPAMC_COMMAND_NEURAL_LEARN,
        RSPAMC_COMMAND_FUZZY_ADD,
        RSPAMC_COMMAND_FUZZY_DEL,
        RSPAMC_COMMAND_FUZZY_DELHASH,
@@ -241,7 +247,7 @@ static const constexpr auto rspamc_commands = rspamd::array_of(
        rspamc_command{
                .cmd = RSPAMC_COMMAND_SYMBOLS,
                .name = "symbols",
-               .path = "checkv2",
+               .path = "/checkv2",
                .description = "scan message and show symbols (default command)",
                .is_controller = FALSE,
                .is_privileged = FALSE,
@@ -250,7 +256,7 @@ static const constexpr auto rspamc_commands = rspamd::array_of(
        rspamc_command{
                .cmd = RSPAMC_COMMAND_LEARN_SPAM,
                .name = "learn_spam",
-               .path = "learnspam",
+               .path = "/learnspam",
                .description = "learn message as spam",
                .is_controller = TRUE,
                .is_privileged = TRUE,
@@ -259,7 +265,7 @@ static const constexpr auto rspamc_commands = rspamd::array_of(
        rspamc_command{
                .cmd = RSPAMC_COMMAND_LEARN_HAM,
                .name = "learn_ham",
-               .path = "learnham",
+               .path = "/learnham",
                .description = "learn message as ham",
                .is_controller = TRUE,
                .is_privileged = TRUE,
@@ -268,16 +274,25 @@ static const constexpr auto rspamc_commands = rspamd::array_of(
        rspamc_command{
                .cmd = RSPAMC_COMMAND_LEARN_CLASS,
                .name = "learn_class",
-               .path = "learnclass",
+               .path = "/learnclass",
                .description = "learn message as class",
                .is_controller = TRUE,
                .is_privileged = TRUE,
                .need_input = TRUE,
                .command_output_func = nullptr},
+       rspamc_command{
+               .cmd = RSPAMC_COMMAND_NEURAL_LEARN,
+               .name = "neural_learn",
+               .path = "/checkv2",
+               .description = "learn neural with a class (use neural_learn:<class>)",
+               .is_controller = FALSE,
+               .is_privileged = FALSE,
+               .need_input = TRUE,
+               .command_output_func = rspamc_symbols_output},
        rspamc_command{
                .cmd = RSPAMC_COMMAND_FUZZY_ADD,
                .name = "fuzzy_add",
-               .path = "fuzzyadd",
+               .path = "/fuzzyadd",
                .description =
                        "add hashes from a message to the fuzzy storage (check -f and -w options for this command)",
                .is_controller = TRUE,
@@ -287,7 +302,7 @@ static const constexpr auto rspamc_commands = rspamd::array_of(
        rspamc_command{
                .cmd = RSPAMC_COMMAND_FUZZY_DEL,
                .name = "fuzzy_del",
-               .path = "fuzzydel",
+               .path = "/fuzzydel",
                .description =
                        "delete hashes from a message from the fuzzy storage (check -f option for this command)",
                .is_controller = TRUE,
@@ -297,7 +312,7 @@ static const constexpr auto rspamc_commands = rspamd::array_of(
        rspamc_command{
                .cmd = RSPAMC_COMMAND_FUZZY_DELHASH,
                .name = "fuzzy_delhash",
-               .path = "fuzzydelhash",
+               .path = "/fuzzydelhash",
                .description =
                        "delete a hash from fuzzy storage (check -f option for this command)",
                .is_controller = TRUE,
@@ -307,7 +322,7 @@ static const constexpr auto rspamc_commands = rspamd::array_of(
        rspamc_command{
                .cmd = RSPAMC_COMMAND_STAT,
                .name = "stat",
-               .path = "stat",
+               .path = "/stat",
                .description = "show rspamd statistics",
                .is_controller = TRUE,
                .is_privileged = FALSE,
@@ -317,7 +332,7 @@ static const constexpr auto rspamc_commands = rspamd::array_of(
        rspamc_command{
                .cmd = RSPAMC_COMMAND_STAT_RESET,
                .name = "stat_reset",
-               .path = "statreset",
+               .path = "/statreset",
                .description = "show and reset rspamd statistics (useful for graphs)",
                .is_controller = TRUE,
                .is_privileged = TRUE,
@@ -326,7 +341,7 @@ static const constexpr auto rspamc_commands = rspamd::array_of(
        rspamc_command{
                .cmd = RSPAMC_COMMAND_COUNTERS,
                .name = "counters",
-               .path = "counters",
+               .path = "/counters",
                .description = "display rspamd symbols statistics",
                .is_controller = TRUE,
                .is_privileged = FALSE,
@@ -335,7 +350,7 @@ static const constexpr auto rspamc_commands = rspamd::array_of(
        rspamc_command{
                .cmd = RSPAMC_COMMAND_UPTIME,
                .name = "uptime",
-               .path = "auth",
+               .path = "/auth",
                .description = "show rspamd uptime",
                .is_controller = TRUE,
                .is_privileged = FALSE,
@@ -344,7 +359,7 @@ static const constexpr auto rspamc_commands = rspamd::array_of(
        rspamc_command{
                .cmd = RSPAMC_COMMAND_ADD_SYMBOL,
                .name = "add_symbol",
-               .path = "addsymbol",
+               .path = "/addsymbol",
                .description = "add or modify symbol settings in rspamd",
                .is_controller = TRUE,
                .is_privileged = TRUE,
@@ -353,7 +368,7 @@ static const constexpr auto rspamc_commands = rspamd::array_of(
        rspamc_command{
                .cmd = RSPAMC_COMMAND_ADD_ACTION,
                .name = "add_action",
-               .path = "addaction",
+               .path = "/addaction",
                .description = "add or modify action settings",
                .is_controller = TRUE,
                .is_privileged = TRUE,
@@ -714,6 +729,7 @@ check_rspamc_command(const char *cmd) -> std::optional<rspamc_command>
                {"learn_spam", RSPAMC_COMMAND_LEARN_SPAM},
                {"learn_ham", RSPAMC_COMMAND_LEARN_HAM},
                {"learn_class", RSPAMC_COMMAND_LEARN_CLASS},
+               {"neural_learn", RSPAMC_COMMAND_NEURAL_LEARN},
                {"fuzzy_add", RSPAMC_COMMAND_FUZZY_ADD},
                {"fuzzy_del", RSPAMC_COMMAND_FUZZY_DEL},
                {"fuzzy_delhash", RSPAMC_COMMAND_FUZZY_DELHASH},
@@ -725,22 +741,34 @@ check_rspamc_command(const char *cmd) -> std::optional<rspamc_command>
 
        std::string cmd_lc = rspamd_string_tolower(cmd);
 
-       // Handle learn_class:classname syntax
-       if (cmd_lc.find("learn_class:") == 0) {
+       /* Handle colon-suffixed commands in a unified way */
+       {
                auto colon_pos = cmd_lc.find(':');
-               if (colon_pos != std::string::npos && colon_pos + 1 < cmd_lc.length()) {
-                       auto class_name = cmd_lc.substr(colon_pos + 1);
-                       // Store class name globally for later use
-                       learn_class_name = g_strdup(class_name.c_str());
-                       // Return the learn_class command
-                       auto elt_it = std::find_if(rspamc_commands.begin(), rspamc_commands.end(), [&](const auto &item) {
-                               return item.cmd == RSPAMC_COMMAND_LEARN_CLASS;
-                       });
-                       if (elt_it != std::end(rspamc_commands)) {
-                               return *elt_it;
+               if (colon_pos != std::string::npos) {
+                       auto base = cmd_lc.substr(0, colon_pos);
+                       auto arg = cmd_lc.substr(colon_pos + 1);
+
+                       if (!arg.empty()) {
+                               auto find_cmd = [](enum rspamc_command_type t) -> std::optional<rspamc_command> {
+                                       const auto it = std::find_if(rspamc_commands.begin(), rspamc_commands.end(), [&t](const auto &item) {
+                                               return item.cmd == t;
+                                       });
+                                       if (it != std::end(rspamc_commands)) {
+                                               return *it;
+                                       }
+                                       return std::nullopt;
+                               };
+
+                               if (base == "learn_class") {
+                                       learn_class_name = arg;
+                                       return find_cmd(RSPAMC_COMMAND_LEARN_CLASS);
+                               }
+                               else if (base == "neural_learn") {
+                                       neural_train = arg; /* allow any class name, plugin validates */
+                                       return find_cmd(RSPAMC_COMMAND_NEURAL_LEARN);
+                               }
                        }
                }
-               return std::nullopt;
        }
 
        auto ct = rspamd::find_map(str_map, std::string_view{cmd_lc});
@@ -887,8 +915,12 @@ add_options(GQueue *opts)
                add_client_header(opts, "Classifier", classifier);
        }
 
-       if (learn_class_name) {
-               add_client_header(opts, "Class", learn_class_name);
+       if (!learn_class_name.empty()) {
+               add_client_header(opts, "Class", learn_class_name.c_str());
+       }
+
+       if (!neural_train.empty()) {
+               add_client_header(opts, "ANN-Train", neural_train.c_str());
        }
 
        if (weight != 0) {
@@ -939,6 +971,15 @@ add_options(GQueue *opts)
                        add_client_header(opts,
                                                          hdr_view.substr(0, std::distance(std::begin(hdr_view), delim_pos)),
                                                          hdr_view.substr(std::distance(std::begin(hdr_view), delim_pos) + 1));
+                       /* Capture Rule header for neural selection */
+                       if (neural_rule.empty()) {
+                               std::string name_copy{hdr_view.substr(0, std::distance(std::begin(hdr_view), delim_pos))};
+                               std::transform(name_copy.begin(), name_copy.end(), name_copy.begin(), [](unsigned char c) { return std::tolower(c); });
+                               if (name_copy == "rule") {
+                                       auto value_view = hdr_view.substr(std::distance(std::begin(hdr_view), delim_pos) + 1);
+                                       neural_rule.assign(value_view.begin(), value_view.end());
+                               }
+                       }
                }
 
                hdr++;
@@ -2049,6 +2090,7 @@ rspamc_process_input(struct ev_loop *ev_base, const struct rspamc_command &cmd,
        uint16_t port;
        GError *err = nullptr;
        std::string hostbuf;
+       std::string path_override;
 
        if (connect_str[0] == '[') {
                p = strrchr(connect_str, ']');
@@ -2096,6 +2138,22 @@ rspamc_process_input(struct ev_loop *ev_base, const struct rspamc_command &cmd,
                }
        }
 
+       /* Dynamic path/port override for neural_learn based on fetched config */
+       if (cmd.cmd == RSPAMC_COMMAND_NEURAL_LEARN && neural_cfg_loaded) {
+               const auto &rule = !neural_rule.empty() ? neural_rule : std::string{"default"};
+               auto it = neural_rule_requires_scan.find(rule);
+               bool requires_scan = true;
+               if (it != neural_rule_requires_scan.end()) {
+                       requires_scan = it->second;
+               }
+               if (!requires_scan) {
+                       path_override = "/plugins/neural/learn_message";
+                       if (p == nullptr) {
+                               port = DEFAULT_CONTROL_PORT;
+                       }
+               }
+       }
+
        conn = rspamd_client_init(http_ctx, ev_base, hostbuf.c_str(), port, timeout, pubkey);
 
        if (conn != nullptr) {
@@ -2104,12 +2162,14 @@ rspamc_process_input(struct ev_loop *ev_base, const struct rspamc_command &cmd,
                cbdata->filename = name;
 
                if (cmd.need_input) {
-                       rspamd_client_command(conn, cmd.path, attrs, in, rspamc_client_cb,
+                       const char *path = path_override.empty() ? cmd.path : path_override.c_str();
+                       rspamd_client_command(conn, path, attrs, in, rspamc_client_cb,
                                                                  cbdata, compressed, dictionary, cbdata->filename.c_str(), &err);
                }
                else {
+                       const char *path = path_override.empty() ? cmd.path : path_override.c_str();
                        rspamd_client_command(conn,
-                                                                 cmd.path,
+                                                                 path,
                                                                  attrs,
                                                                  nullptr,
                                                                  rspamc_client_cb,
@@ -2281,6 +2341,102 @@ rspamc_kwattr_free(gpointer p)
        g_free(h);
 }
 
+/* Fetch /controller/neural/config once per run and populate neural_cfg_loaded
+ * and neural_rule_requires_scan map. */
+static void rspamc_neural_config_cb(struct rspamd_client_connection *conn,
+                                                                       struct rspamd_http_message *msg,
+                                                                       const char *name, ucl_object_t *result, GString *input,
+                                                                       gpointer ud, double start_time, double send_time,
+                                                                       const char *body, gsize bodylen,
+                                                                       GError *err)
+{
+       /* Populate map: data.rules[rule].requires_scan */
+       if (result != nullptr) {
+               const auto *data = ucl_object_lookup(result, "data");
+               if (data && ucl_object_type(data) == UCL_OBJECT) {
+                       const auto *rules = ucl_object_lookup(data, "rules");
+                       if (rules && ucl_object_type(rules) == UCL_OBJECT) {
+                               ucl_object_iter_t it = nullptr;
+                               const ucl_object_t *cur;
+                               while ((cur = ucl_object_iterate(rules, &it, true)) != nullptr) {
+                                       std::string rule_name = ucl_object_key(cur) ? ucl_object_key(cur) : "";
+                                       auto requires_scan = true;
+                                       const auto *rq = ucl_object_lookup(cur, "requires_scan");
+                                       if (rq) {
+                                               requires_scan = ucl_object_toboolean(rq);
+                                       }
+                                       neural_rule_requires_scan[rule_name] = requires_scan;
+                               }
+                               neural_cfg_loaded = true;
+                       }
+               }
+               ucl_object_unref(result);
+       }
+       else if (err) {
+               /* Do not fail the whole run if config not available */
+               neural_cfg_loaded = false;
+       }
+
+       rspamd_client_destroy(conn);
+}
+
+static void
+rspamc_fetch_neural_config(struct ev_loop *ev_base, GQueue *attrs)
+{
+       /* Build connection to controller port */
+       const char *p;
+       uint16_t port;
+       std::string hostbuf;
+       GError *err = nullptr;
+
+       if (connect_str[0] == '[') {
+               p = strrchr(connect_str, ']');
+               if (p != nullptr) {
+                       hostbuf.assign(connect_str + 1, (std::size_t) (p - connect_str - 1));
+                       p++;
+               }
+               else {
+                       p = connect_str;
+               }
+       }
+       else {
+               p = connect_str;
+       }
+
+       p = strrchr(p, ':');
+       if (hostbuf.empty()) {
+               if (p != nullptr) {
+                       hostbuf.assign(connect_str, (std::size_t) (p - connect_str));
+               }
+               else {
+                       hostbuf.assign(connect_str);
+               }
+       }
+
+       if (p != nullptr) {
+               port = strtoul(p + 1, nullptr, 10);
+       }
+       else {
+               /* Default to controller port if not specified */
+               port = DEFAULT_CONTROL_PORT;
+       }
+
+       auto *conn = rspamd_client_init(http_ctx, ev_base, hostbuf.c_str(), port, timeout, pubkey);
+       if (conn != nullptr) {
+               /* Minimal headers; reuse attrs so users can pass Password, etc. */
+               rspamd_client_command(conn,
+                                                         "/plugins/neural/config",
+                                                         attrs,
+                                                         nullptr,
+                                                         rspamc_neural_config_cb,
+                                                         nullptr,
+                                                         compressed,
+                                                         dictionary,
+                                                         "neural_config",
+                                                         &err);
+       }
+}
+
 int main(int argc, char **argv, char **env)
 {
        auto *kwattrs = g_queue_new();
@@ -2398,6 +2554,13 @@ int main(int argc, char **argv, char **env)
        add_options(kwattrs);
        auto cmd = maybe_cmd.value();
 
+       /* Preload neural config once if we are going to use neural_learn */
+       if (cmd.cmd == RSPAMC_COMMAND_NEURAL_LEARN && !neural_cfg_loaded) {
+               rspamc_fetch_neural_config(event_loop, kwattrs);
+               /* Drive loop once to complete config request */
+               ev_loop(event_loop, 0);
+       }
+
        if (start_argc == argc && files_list == nullptr) {
                /* Do command without input or with stdin */
                if (empty_input) {
index 24240d3c25b8b20a1217c2b52b5cc0462856c93a..adffcf43f091ae651330dc82530a72a00cf4b9c1 100644 (file)
@@ -461,8 +461,14 @@ rspamd_client_command(struct rspamd_client_connection *conn,
         */
        rspamd_http_message_add_header(req->msg, "Accept", "application/msgpack");
 
-       req->msg->url = rspamd_fstring_append(req->msg->url, "/", 1);
-       req->msg->url = rspamd_fstring_append(req->msg->url, command, strlen(command));
+       /* Append path ensuring a single leading slash */
+       if (command != NULL && command[0] == '/') {
+               req->msg->url = rspamd_fstring_append(req->msg->url, command, strlen(command));
+       }
+       else {
+               req->msg->url = rspamd_fstring_append(req->msg->url, "/", 1);
+               req->msg->url = rspamd_fstring_append(req->msg->url, command ? command : "", command ? strlen(command) : 0);
+       }
 
        conn->req = req;
        conn->start_time = rspamd_get_ticks(FALSE);
index 633a45854a539fd9d88ef359f9e40cedc66bfba1..1e8a135f18a544fa6da2f3e13cd056e44466d2b2 100644 (file)
@@ -69,20 +69,20 @@ local function new_ann_profile(task, rule, set, version)
   local function add_cb(err, _)
     if err then
       rspamd_logger.errx(task, 'cannot store ANN profile for %s:%s at %s : %s',
-        rule.prefix, set.name, profile.redis_key, err)
+          rule.prefix, set.name, profile.redis_key, err)
     else
       rspamd_logger.infox(task, 'created new ANN profile for %s:%s, data stored at prefix %s',
-        rule.prefix, set.name, profile.redis_key)
+          rule.prefix, set.name, profile.redis_key)
     end
   end
 
   lua_redis.redis_make_request(task,
-    rule.redis,
-    nil,
-    true,   -- is write
-    add_cb, --callback
-    'ZADD', -- command
-    { set.prefix, tostring(rspamd_util.get_time()), profile_serialized }
+      rule.redis,
+      nil,
+      true, -- is write
+      add_cb, --callback
+      'ZADD', -- command
+      { set.prefix, tostring(rspamd_util.get_time()), profile_serialized }
   )
 
   return profile
@@ -103,18 +103,18 @@ local function ann_scores_filter(task)
         profile = set.ann
       else
         lua_util.debugm(N, task, 'no ann loaded for %s:%s',
-          rule.prefix, set.name)
+            rule.prefix, set.name)
       end
     else
       lua_util.debugm(N, task, 'no ann defined in %s for settings id %s',
-        rule.prefix, sid)
+          rule.prefix, sid)
     end
 
     if ann then
       local function after_features(vec, meta)
         if profile.providers_digest and meta and meta.digest and profile.providers_digest ~= meta.digest then
           lua_util.debugm(N, task, 'providers digest mismatch for %s:%s, skip ANN apply',
-            rule.prefix, set.name)
+              rule.prefix, set.name)
           vec = nil
         end
 
@@ -131,7 +131,7 @@ local function ann_scores_filter(task)
         local symscore = string.format('%.3f', score)
         task:cache_set(rule.prefix .. '_neural_score', score)
         lua_util.debugm(N, task, '%s:%s:%s ann score: %s',
-          rule.prefix, set.name, set.ann.version, symscore)
+            rule.prefix, set.name, set.ann.version, symscore)
 
         if score > 0 then
           local result = score
@@ -152,8 +152,8 @@ local function ann_scores_filter(task)
             end
           else
             lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (spam threshold)',
-              rule.prefix, set.name, set.ann.version, symscore,
-              spam_threshold)
+                rule.prefix, set.name, set.ann.version, symscore,
+                spam_threshold)
           end
         else
           local result = -(score)
@@ -174,8 +174,8 @@ local function ann_scores_filter(task)
             end
           else
             lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (ham threshold)',
-              rule.prefix, set.name, set.ann.version, result,
-              ham_threshold)
+                rule.prefix, set.name, set.ann.version, result,
+                ham_threshold)
           end
         end
       end
@@ -194,21 +194,40 @@ local function ann_push_task_result(rule, task, verdict, score, set)
   local train_opts = rule.train
   local learn_spam, learn_ham
   local skip_reason = 'unknown'
+  local manual_train = false
+
+  -- First, honor explicit manual training header if present
+  do
+    local hdr = task:get_request_header('ANN-Train')
+    if hdr then
+      local hv = tostring(hdr):lower()
+      lua_util.debugm(N, task, 'found ANN-Train header, enable manual train mode', hv)
+      if hv == 'spam' then
+        learn_spam = true
+        manual_train = true
+      elseif hv == 'ham' then
+        learn_ham = true
+        manual_train = true
+      else
+        skip_reason = 'no explicit header'
+      end
+    end
+  end
 
-  if not train_opts.store_pool_only and train_opts.autotrain then
+  if not manual_train and (not train_opts.store_pool_only and train_opts.autotrain) then
     if train_opts.spam_score then
       learn_spam = score >= train_opts.spam_score
 
       if not learn_spam then
         skip_reason = string.format('score < spam_score: %f < %f',
-          score, train_opts.spam_score)
+            score, train_opts.spam_score)
       end
     else
       learn_spam = verdict == 'spam' or verdict == 'junk'
 
       if not learn_spam then
         skip_reason = string.format('verdict: %s',
-          verdict)
+            verdict)
       end
     end
 
@@ -216,29 +235,18 @@ local function ann_push_task_result(rule, task, verdict, score, set)
       learn_ham = score <= train_opts.ham_score
       if not learn_ham then
         skip_reason = string.format('score > ham_score: %f > %f',
-          score, train_opts.ham_score)
+            score, train_opts.ham_score)
       end
     else
       learn_ham = verdict == 'ham'
 
       if not learn_ham then
         skip_reason = string.format('verdict: %s',
-          verdict)
+            verdict)
       end
     end
   else
-    -- Train by request header
-    local hdr = task:get_request_header('ANN-Train')
-
-    if hdr then
-      if hdr:lower() == 'spam' then
-        learn_spam = true
-      elseif hdr:lower() == 'ham' then
-        learn_ham = true
-      else
-        skip_reason = 'no explicit header'
-      end
-    elseif train_opts.store_pool_only then
+    if train_opts.store_pool_only and not manual_train then
       local ucl = require "ucl"
       learn_ham = false
       learn_spam = false
@@ -272,7 +280,7 @@ local function ann_push_task_result(rule, task, verdict, score, set)
       if not err and type(data) == 'table' then
         local nspam, nham = data[1], data[2]
 
-        if neural_common.can_push_train_vector(rule, task, learn_type, nspam, nham) then
+        if manual_train or neural_common.can_push_train_vector(rule, task, learn_type, nspam, nham) then
           local vec
           if rule.providers and #rule.providers > 0 then
             -- Note: this training path remains sync for now; vectors are pushed when computed
@@ -288,41 +296,41 @@ local function ann_push_task_result(rule, task, verdict, score, set)
           local function learn_vec_cb(redis_err)
             if redis_err then
               rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s',
-                rule.prefix, set.name, redis_err)
+                  rule.prefix, set.name, redis_err)
             else
               lua_util.debugm(N, task,
-                "add train data for ANN rule " ..
-                "%s:%s, save %s vector of %s elts in %s key; %s bytes compressed",
-                rule.prefix, set.name, learn_type, #vec, target_key, #str)
+                  "add train data for ANN rule " ..
+                      "%s:%s, save %s vector of %s elts in %s key; %s bytes compressed",
+                  rule.prefix, set.name, learn_type, #vec, target_key, #str)
             end
           end
 
           lua_redis.redis_make_request(task,
-            rule.redis,
-            nil,
-            true,               -- is write
-            learn_vec_cb,       --callback
-            'SADD',             -- command
-            { target_key, str } -- arguments
+              rule.redis,
+              nil,
+              true, -- is write
+              learn_vec_cb, --callback
+              'SADD', -- command
+              { target_key, str } -- arguments
           )
         else
           lua_util.debugm(N, task,
-            "do not add %s train data for ANN rule " ..
-            "%s:%s",
-            learn_type, rule.prefix, set.name)
+              "do not add %s train data for ANN rule " ..
+                  "%s:%s",
+              learn_type, rule.prefix, set.name)
         end
       else
         if err then
           rspamd_logger.errx(task, 'cannot check if we can train %s:%s : %s',
-            rule.prefix, set.name, err)
+              rule.prefix, set.name, err)
         elseif type(data) == 'string' then
           -- nil return value
           rspamd_logger.infox(task, "cannot learn %s ANN %s:%s; redis_key: %s: locked for learning: %s",
-            learn_type, rule.prefix, set.name, set.ann.redis_key, data)
+              learn_type, rule.prefix, set.name, set.ann.redis_key, data)
         else
           rspamd_logger.errx(task, 'cannot check if we can train %s:%s : type of Redis key %s is %s, expected table' ..
-            'please remove this key from Redis manually if you perform upgrade from the previous version',
-            rule.prefix, set.name, set.ann.redis_key, type(data))
+              'please remove this key from Redis manually if you perform upgrade from the previous version',
+              rule.prefix, set.name, set.ann.redis_key, type(data))
         end
       end
     end
@@ -333,25 +341,25 @@ local function ann_push_task_result(rule, task, verdict, score, set)
         -- Need to create or load a profile corresponding to the current configuration
         set.ann = new_ann_profile(task, rule, set, 0)
         lua_util.debugm(N, task,
-          'requested new profile for %s, set.ann is missing',
-          set.name)
+            'requested new profile for %s, set.ann is missing',
+            set.name)
       end
 
       lua_redis.exec_redis_script(neural_common.redis_script_id.vectors_len,
-        { task = task, is_write = false },
-        vectors_len_cb,
-        {
-          set.ann.redis_key,
-        })
+          { task = task, is_write = false },
+          vectors_len_cb,
+          {
+            set.ann.redis_key,
+          })
     else
       lua_util.debugm(N, task,
-        'do not push data: train condition not satisfied; reason: not checked existing ANNs')
+          'do not push data: train condition not satisfied; reason: not checked existing ANNs')
     end
   else
     lua_util.debugm(N, task,
-      'do not push data to key %s: train condition not satisfied; reason: %s',
-      (set.ann or {}).redis_key,
-      skip_reason)
+        'do not push data to key %s: train condition not satisfied; reason: %s',
+        (set.ann or {}).redis_key,
+        skip_reason)
   end
 end
 
@@ -376,16 +384,16 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key)
   local function redis_ham_cb(err, data)
     if err or type(data) ~= 'table' then
       rspamd_logger.errx(rspamd_config, 'cannot get ham tokens for ANN %s from redis: %s',
-        ann_key, err)
+          ann_key, err)
       -- Unlock on error
       lua_redis.redis_make_request_taskless(ev_base,
-        rspamd_config,
-        rule.redis,
-        nil,
-        true,                                            -- is write
-        neural_common.gen_unlock_cb(rule, set, ann_key), --callback
-        'HDEL',                                          -- command
-        { ann_key, 'lock' }
+          rspamd_config,
+          rule.redis,
+          nil,
+          true, -- is write
+          neural_common.gen_unlock_cb(rule, set, ann_key), --callback
+          'HDEL', -- command
+          { ann_key, 'lock' }
       )
     else
       -- Decompress and convert to numbers each training vector
@@ -406,29 +414,29 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key)
   local function redis_spam_cb(err, data)
     if err or type(data) ~= 'table' then
       rspamd_logger.errx(rspamd_config, 'cannot get spam tokens for ANN %s from redis: %s',
-        ann_key, err)
+          ann_key, err)
       -- Unlock ANN on error
       lua_redis.redis_make_request_taskless(ev_base,
-        rspamd_config,
-        rule.redis,
-        nil,
-        true,                                            -- is write
-        neural_common.gen_unlock_cb(rule, set, ann_key), --callback
-        'HDEL',                                          -- command
-        { ann_key, 'lock' }
+          rspamd_config,
+          rule.redis,
+          nil,
+          true, -- is write
+          neural_common.gen_unlock_cb(rule, set, ann_key), --callback
+          'HDEL', -- command
+          { ann_key, 'lock' }
       )
     else
       -- Decompress and convert to numbers each training vector
       spam_elts = process_training_vectors(data)
       -- Now get ham vectors...
       lua_redis.redis_make_request_taskless(ev_base,
-        rspamd_config,
-        rule.redis,
-        nil,
-        false,        -- is write
-        redis_ham_cb, --callback
-        'SMEMBERS',   -- command
-        { ann_key .. '_ham_set' }
+          rspamd_config,
+          rule.redis,
+          nil,
+          false, -- is write
+          redis_ham_cb, --callback
+          'SMEMBERS', -- command
+          { ann_key .. '_ham_set' }
       )
     end
   end
@@ -436,33 +444,33 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key)
   local function redis_lock_cb(err, data)
     if err then
       rspamd_logger.errx(rspamd_config, 'cannot call lock script for ANN %s from redis: %s',
-        ann_key, err)
+          ann_key, err)
     elseif type(data) == 'number' and data == 1 then
       -- ANN is locked, so we can extract SPAM and HAM vectors and spawn learning
       lua_redis.redis_make_request_taskless(ev_base,
-        rspamd_config,
-        rule.redis,
-        nil,
-        false,         -- is write
-        redis_spam_cb, --callback
-        'SMEMBERS',    -- command
-        { ann_key .. '_spam_set' }
+          rspamd_config,
+          rule.redis,
+          nil,
+          false, -- is write
+          redis_spam_cb, --callback
+          'SMEMBERS', -- command
+          { ann_key .. '_spam_set' }
       )
 
       rspamd_logger.infox(rspamd_config, 'lock ANN %s:%s (key name %s) for learning',
-        rule.prefix, set.name, ann_key)
+          rule.prefix, set.name, ann_key)
     else
       local lock_tm = tonumber(data[1])
       rspamd_logger.infox(rspamd_config, 'do not learn ANN %s:%s (key name %s), ' ..
-        'locked by another host %s at %s', rule.prefix, set.name, ann_key,
-        data[2], os.date('%c', lock_tm))
+          'locked by another host %s at %s', rule.prefix, set.name, ann_key,
+          data[2], os.date('%c', lock_tm))
     end
   end
 
   -- Check if we are already learning this network
   if set.learning_spawned then
     rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, already learning another ANN',
-      ann_key)
+        ann_key)
     return
   end
 
@@ -470,14 +478,14 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key)
   -- This script returns either a boolean or a pair {'lock_time', 'hostname'} when
   -- ANN is locked by another host (or a process, meh)
   lua_redis.exec_redis_script(neural_common.redis_script_id.maybe_lock,
-    { ev_base = ev_base, is_write = true },
-    redis_lock_cb,
-    {
-      ann_key,
-      tostring(os.time()),
-      tostring(math.max(10.0, rule.watch_interval * 2)),
-      rspamd_util.get_hostname()
-    })
+      { ev_base = ev_base, is_write = true },
+      redis_lock_cb,
+      {
+        ann_key,
+        tostring(os.time()),
+        tostring(math.max(10.0, rule.watch_interval * 2)),
+        rspamd_util.get_hostname()
+      })
 end
 
 -- This function loads new ann from Redis
@@ -492,7 +500,7 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
   local function data_cb(err, data)
     if err then
       rspamd_logger.errx(rspamd_config, 'cannot get ANN data from key: %s; %s',
-        ann_key, err)
+          ann_key, err)
     else
       if type(data) == 'table' then
         if type(data[1]) == 'userdata' and data[1].cookie == text_cookie then
@@ -501,7 +509,7 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
 
           if _err or not ann_data then
             rspamd_logger.errx(rspamd_config, 'cannot decompress ANN for %s from Redis key %s: %s',
-              rule.prefix .. ':' .. set.name, ann_key, _err)
+                rule.prefix .. ':' .. set.name, ann_key, _err)
             return
           else
             ann = rspamd_kann.load(ann_data)
@@ -525,26 +533,26 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
               end
               -- Also update rank for the loaded ANN to avoid removal
               lua_redis.redis_make_request_taskless(ev_base,
-                rspamd_config,
-                rule.redis,
-                nil,
-                true,    -- is write
-                rank_cb, --callback
-                'ZADD',  -- command
-                { set.prefix, tostring(rspamd_util.get_time()), profile_serialized }
+                  rspamd_config,
+                  rule.redis,
+                  nil,
+                  true, -- is write
+                  rank_cb, --callback
+                  'ZADD', -- command
+                  { set.prefix, tostring(rspamd_util.get_time()), profile_serialized }
               )
               rspamd_logger.infox(rspamd_config,
-                'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s',
-                rule.prefix, set.name, ann_key, #data[1], profile.version)
+                  'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s',
+                  rule.prefix, set.name, ann_key, #data[1], profile.version)
             else
               rspamd_logger.errx(rspamd_config,
-                'cannot unpack/deserialise ANN for %s:%s from Redis key %s',
-                rule.prefix, set.name, ann_key)
+                  'cannot unpack/deserialise ANN for %s:%s from Redis key %s',
+                  rule.prefix, set.name, ann_key)
             end
           end
         else
           lua_util.debugm(N, rspamd_config, 'missing ANN for %s:%s in Redis key %s',
-            rule.prefix, set.name, ann_key)
+              rule.prefix, set.name, ann_key)
         end
 
         if set.ann and set.ann.ann and type(data[2]) == 'userdata' and data[2].cookie == text_cookie then
@@ -556,8 +564,8 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
             local roc_thresholds = parser:get_object()
             set.ann.roc_thresholds = roc_thresholds
             rspamd_logger.infox(rspamd_config,
-              'loaded ROC thresholds for %s:%s; version=%s',
-              rule.prefix, set.name, profile.version)
+                'loaded ROC thresholds for %s:%s; version=%s',
+                rule.prefix, set.name, profile.version)
             rspamd_logger.debugx("ROC thresholds: %s", roc_thresholds)
           end
         end
@@ -570,19 +578,19 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
               -- We can use PCA
               set.ann.pca = rspamd_tensor.load(pca_data)
               rspamd_logger.infox(rspamd_config,
-                'loaded PCA for ANN for %s:%s from %s; %s bytes compressed; version=%s',
-                rule.prefix, set.name, ann_key, #data[3], profile.version)
+                  'loaded PCA for ANN for %s:%s from %s; %s bytes compressed; version=%s',
+                  rule.prefix, set.name, ann_key, #data[3], profile.version)
             else
               -- no need in pca, why is it there?
               rspamd_logger.warnx(rspamd_config,
-                'extra PCA for ANN for %s:%s from Redis key %s: no max inputs defined',
-                rule.prefix, set.name, ann_key)
+                  'extra PCA for ANN for %s:%s from Redis key %s: no max inputs defined',
+                  rule.prefix, set.name, ann_key)
             end
           else
             -- pca can be missing merely if we have no max_inputs
             if rule.max_inputs then
               rspamd_logger.errx(rspamd_config, 'cannot unpack/deserialise ANN for %s:%s from Redis key %s: no PCA: %s',
-                rule.prefix, set.name, ann_key, _err)
+                  rule.prefix, set.name, ann_key, _err)
               set.ann.ann = nil
             else
               -- It is okay
@@ -611,19 +619,19 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
         end
       else
         lua_util.debugm(N, rspamd_config, 'no ANN key for %s:%s in Redis key %s',
-          rule.prefix, set.name, ann_key)
+            rule.prefix, set.name, ann_key)
       end
     end
   end
   lua_redis.redis_make_request_taskless(ev_base,
-    rspamd_config,
-    rule.redis,
-    nil,
-    false,                                                                       -- is write
-    data_cb,                                                                     --callback
-    'HMGET',                                                                     -- command
-    { ann_key, 'ann', 'roc_thresholds', 'pca', 'providers_meta', 'norm_stats' }, -- arguments
-    { opaque_data = true }
+      rspamd_config,
+      rule.redis,
+      nil,
+      false, -- is write
+      data_cb, --callback
+      'HMGET', -- command
+      { ann_key, 'ann', 'roc_thresholds', 'pca', 'providers_meta', 'norm_stats' }, -- arguments
+      { opaque_data = true }
   )
 end
 
@@ -659,34 +667,34 @@ local function process_existing_ann(_, ev_base, rule, set, profiles)
         if set.ann.version < sel_elt.version then
           -- Load new ann
           rspamd_logger.infox(rspamd_config, 'ann %s is changed, ' ..
-            'our version = %s, remote version = %s',
-            rule.prefix .. ':' .. set.name,
-            set.ann.version,
-            sel_elt.version)
+              'our version = %s, remote version = %s',
+              rule.prefix .. ':' .. set.name,
+              set.ann.version,
+              sel_elt.version)
           load_new_ann(rule, ev_base, set, sel_elt, min_diff)
         else
           lua_util.debugm(N, rspamd_config, 'ann %s is not changed, ' ..
-            'our version = %s, remote version = %s',
-            rule.prefix .. ':' .. set.name,
-            set.ann.version,
-            sel_elt.version)
+              'our version = %s, remote version = %s',
+              rule.prefix .. ':' .. set.name,
+              set.ann.version,
+              sel_elt.version)
         end
       else
         -- We have some different ANN, so we need to compare distance
         if set.ann.distance > min_diff then
           -- Load more specific ANN
           rspamd_logger.infox(rspamd_config, 'more specific ann is available for %s, ' ..
-            'our distance = %s, remote distance = %s',
-            rule.prefix .. ':' .. set.name,
-            set.ann.distance,
-            min_diff)
+              'our distance = %s, remote distance = %s',
+              rule.prefix .. ':' .. set.name,
+              set.ann.distance,
+              min_diff)
           load_new_ann(rule, ev_base, set, sel_elt, min_diff)
         else
           lua_util.debugm(N, rspamd_config, 'ann %s is not changed or less specific, ' ..
-            'our distance = %s, remote distance = %s',
-            rule.prefix .. ':' .. set.name,
-            set.ann.distance,
-            min_diff)
+              'our distance = %s, remote distance = %s',
+              rule.prefix .. ':' .. set.name,
+              set.ann.distance,
+              min_diff)
         end
       end
     else
@@ -724,14 +732,14 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
     local ann_key = sel_elt.redis_key
 
     lua_util.debugm(N, rspamd_config, "check if ANN %s needs to be trained",
-      ann_key)
+        ann_key)
 
     -- Create continuation closure
     local redis_len_cb_gen = function(cont_cb, what, is_final)
       return function(err, data)
         if err then
           rspamd_logger.errx(rspamd_config,
-            'cannot get ANN %s trains %s from redis: %s', what, ann_key, err)
+              'cannot get ANN %s trains %s from redis: %s', what, ann_key, err)
         elseif data and type(data) == 'number' or type(data) == 'string' then
           local ntrains = tonumber(data) or 0
           lens[what] = ntrains
@@ -752,31 +760,31 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
               end
               if max_len >= rule.train.max_trains and fun.all(len_bias_check_pred, lens) then
                 lua_util.debugm(N, rspamd_config,
-                  'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
-                  ann_key, lens, rule.train.max_trains, what)
+                    'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
+                    ann_key, lens, rule.train.max_trains, what)
                 cont_cb()
               else
                 lua_util.debugm(N, rspamd_config,
-                  'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
-                  ann_key, what, lens, rule.train.max_trains)
+                    'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
+                    ann_key, what, lens, rule.train.max_trains)
               end
             else
               -- Probabilistic mode, just ensure that at least one vector is okay
               if min_len > 0 and max_len >= rule.train.max_trains then
                 lua_util.debugm(N, rspamd_config,
-                  'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
-                  ann_key, lens, rule.train.max_trains, what)
+                    'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
+                    ann_key, lens, rule.train.max_trains, what)
                 cont_cb()
               else
                 lua_util.debugm(N, rspamd_config,
-                  'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
-                  ann_key, what, lens, rule.train.max_trains)
+                    'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
+                    ann_key, what, lens, rule.train.max_trains)
               end
             end
           else
             lua_util.debugm(N, rspamd_config,
-              'checked %s vectors in ANN %s: %s vectors; %s required, need to check other class vectors',
-              what, ann_key, ntrains, rule.train.max_trains)
+                'checked %s vectors in ANN %s: %s vectors; %s required, need to check other class vectors',
+                what, ann_key, ntrains, rule.train.max_trains)
             cont_cb()
           end
         end
@@ -785,32 +793,32 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
 
     local function initiate_train()
       rspamd_logger.infox(rspamd_config,
-        'need to learn ANN %s after %s required learn vectors',
-        ann_key, lens)
+          'need to learn ANN %s after %s required learn vectors',
+          ann_key, lens)
       do_train_ann(worker, ev_base, rule, set, ann_key)
     end
 
     -- Spam vector is OK, check ham vector length
     local function check_ham_len()
       lua_redis.redis_make_request_taskless(ev_base,
-        rspamd_config,
-        rule.redis,
-        nil,
-        false,                                         -- is write
-        redis_len_cb_gen(initiate_train, 'ham', true), --callback
-        'SCARD',                                       -- command
-        { ann_key .. '_ham_set' }
+          rspamd_config,
+          rule.redis,
+          nil,
+          false, -- is write
+          redis_len_cb_gen(initiate_train, 'ham', true), --callback
+          'SCARD', -- command
+          { ann_key .. '_ham_set' }
       )
     end
 
     lua_redis.redis_make_request_taskless(ev_base,
-      rspamd_config,
-      rule.redis,
-      nil,
-      false,                                          -- is write
-      redis_len_cb_gen(check_ham_len, 'spam', false), --callback
-      'SCARD',                                        -- command
-      { ann_key .. '_spam_set' }
+        rspamd_config,
+        rule.redis,
+        nil,
+        false, -- is write
+        redis_len_cb_gen(check_ham_len, 'spam', false), --callback
+        'SCARD', -- command
+        { ann_key .. '_spam_set' }
     )
   end
 end
@@ -823,7 +831,7 @@ local function load_ann_profile(element)
   local res, ucl_err = parser:parse_string(element)
   if not res then
     rspamd_logger.warnx(rspamd_config, 'cannot parse ANN from redis: %s',
-      ucl_err)
+        ucl_err)
     return nil
   else
     local profile = parser:get_object()
@@ -843,11 +851,11 @@ local function check_anns(worker, cfg, ev_base, rule, process_callback, what)
     local function members_cb(err, data)
       if err then
         rspamd_logger.errx(cfg, 'cannot get ANNs list from redis: %s',
-          err)
+            err)
         set.can_store_vectors = true
       elseif type(data) == 'table' then
         lua_util.debugm(N, cfg, '%s: process element %s:%s',
-          what, rule.prefix, set.name)
+            what, rule.prefix, set.name)
         process_callback(worker, ev_base, rule, set, fun.map(load_ann_profile, data))
         set.can_store_vectors = true
       end
@@ -859,13 +867,13 @@ local function check_anns(worker, cfg, ev_base, rule, process_callback, what)
       -- Select the most appropriate to our profile but it should not differ by more
       -- than 30% of symbols
       lua_redis.redis_make_request_taskless(ev_base,
-        cfg,
-        rule.redis,
-        nil,
-        false,                                               -- is write
-        members_cb,                                          --callback
-        'ZREVRANGE',                                         -- command
-        { set.prefix, '0', tostring(settings.max_profiles) } -- arguments
+          cfg,
+          rule.redis,
+          nil,
+          false, -- is write
+          members_cb, --callback
+          'ZREVRANGE', -- command
+          { set.prefix, '0', tostring(settings.max_profiles) } -- arguments
       )
     end
   end -- Cycle over all settings
@@ -879,23 +887,23 @@ local function cleanup_anns(rule, cfg, ev_base)
     local function invalidate_cb(err, data)
       if err then
         rspamd_logger.errx(cfg, 'cannot exec invalidate script in redis: %s',
-          err)
+            err)
       elseif type(data) == 'table' then
         for _, expired in ipairs(data) do
           local profile = load_ann_profile(expired)
           rspamd_logger.infox(cfg, 'invalidated ANN for %s; redis key: %s; version=%s',
-            rule.prefix .. ':' .. set.name,
-            profile.redis_key,
-            profile.version)
+              rule.prefix .. ':' .. set.name,
+              profile.redis_key,
+              profile.version)
         end
       end
     end
 
     if type(set) == 'table' then
       lua_redis.exec_redis_script(neural_common.redis_script_id.maybe_invalidate,
-        { ev_base = ev_base, is_write = true },
-        invalidate_cb,
-        { set.prefix, tostring(settings.max_profiles) })
+          { ev_base = ev_base, is_write = true },
+          invalidate_cb,
+          { set.prefix, tostring(settings.max_profiles) })
     end
   end
 end
@@ -914,14 +922,14 @@ local function ann_push_vector(task)
 
   if verdict == 'passthrough' then
     lua_util.debugm(N, task, 'ignore task as its verdict is %s(%s)',
-      verdict, score)
+        verdict, score)
 
     return
   end
 
   if score ~= score then
     lua_util.debugm(N, task, 'ignore task as its score is nan (%s verdict)',
-      verdict)
+        verdict)
 
     return
   end
@@ -991,7 +999,7 @@ for k, r in pairs(rules) do
 
   if rule_elt.max_inputs and not has_blas then
     rspamd_logger.errx('cannot set max inputs to %s as BLAS is not compiled in',
-      rule_elt.name, rule_elt.max_inputs)
+        rule_elt.name, rule_elt.max_inputs)
     rule_elt.max_inputs = nil
   end
 
@@ -1003,7 +1011,7 @@ for k, r in pairs(rules) do
       end
       if (pcfg.type == 'llm' or pcfg.name == 'llm') and not (pcfg.model or (rspamd_config:get_all_opt('gpt') or {}).model) then
         rspamd_logger.errx(rspamd_config,
-          'llm provider in rule %s requires model; please set providers[i].model or gpt.model', k)
+            'llm provider in rule %s requires model; please set providers[i].model or gpt.model', k)
       end
     end
   end
@@ -1054,21 +1062,21 @@ for _, rule in pairs(settings.rules) do
   rspamd_config:add_on_load(function(cfg, ev_base, worker)
     if worker:is_scanner() then
       rspamd_config:add_periodic(ev_base, 0.0,
-        function(_, _)
-          return check_anns(worker, cfg, ev_base, rule, process_existing_ann,
-            'try_load_ann')
-        end)
+          function(_, _)
+            return check_anns(worker, cfg, ev_base, rule, process_existing_ann,
+                'try_load_ann')
+          end)
     end
 
     if worker:is_primary_controller() then
       -- We also want to train neural nets when they have enough data
       rspamd_config:add_periodic(ev_base, 0.0,
-        function(_, _)
-          -- Clean old ANNs
-          cleanup_anns(rule, cfg, ev_base)
-          return check_anns(worker, cfg, ev_base, rule, maybe_train_existing_ann,
-            'try_train_ann')
-        end)
+          function(_, _)
+            -- Clean old ANNs
+            cleanup_anns(rule, cfg, ev_base)
+            return check_anns(worker, cfg, ev_base, rule, maybe_train_existing_ann,
+                'try_train_ann')
+          end)
     end
   end)
 end
index f61998f46ec0f15a93ff8030a4d9193457da8341..1bda5e5c2ece062603438db28ee2c8d084903ad8 100644 (file)
@@ -475,6 +475,15 @@ Run Dummy Https
   Wait Until Created  /tmp/dummy_https.pid  timeout=2 second
   Export Scoped Variables  ${RSPAMD_SCOPE}  DUMMY_HTTPS_PROC=${result}
 
+Run Dummy Llm
+  ${result} =  Start Process  python3  ${RSPAMD_TESTDIR}/util/dummy_llm.py  18080
+  Wait Until Created  /tmp/dummy_llm.pid  timeout=2 second
+  Export Scoped Variables  ${RSPAMD_SCOPE}  DUMMY_LLM_PROC=${result}
+
+Dummy Llm Teardown
+  Terminate Process  ${DUMMY_LLM_PROC}
+  Wait For Process  ${DUMMY_LLM_PROC}
+
 Dummy Http Teardown
   Terminate Process  ${DUMMY_HTTP_PROC}
   Wait For Process  ${DUMMY_HTTP_PROC}