From: Vsevolod Stakhov Date: Thu, 28 Aug 2025 11:32:16 +0000 (+0100) Subject: [Project] Rework rspamc to allow training of different neural types X-Git-Tag: 3.13.0~22^2~3 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=597ae4da82b16bb28652d922ac0180eccd3ff456;p=thirdparty%2Frspamd.git [Project] Rework rspamc to allow training of different neural types --- diff --git a/lualib/plugins/neural/providers/llm.lua b/lualib/plugins/neural/providers/llm.lua index 33301e9084..4f17979c50 100644 --- a/lualib/plugins/neural/providers/llm.lua +++ b/lualib/plugins/neural/providers/llm.lua @@ -14,8 +14,8 @@ local llm_common = require "llm_common" local N = "neural.llm" local function select_text(task) - local content = llm_common.build_llm_input(task) - return content + local input_tbl = llm_common.build_llm_input(task) + return input_tbl end local function compose_llm_settings(pcfg) @@ -72,23 +72,29 @@ neural_common.register_provider('llm', { return nil end - local content = select_text(task) - if not content or #content == 0 then + local input_tbl = select_text(task) + if not input_tbl then rspamd_logger.debugm(N, task, 'llm provider has no content to embed; skip') return nil end + -- Build request input string (text then optional subject), keeping rspamd_text intact + local input_string = input_tbl.text or '' + if input_tbl.subject and input_tbl.subject ~= '' then + input_string = input_string .. "\nSubject: " .. input_tbl.subject + end + local body if llm.type == 'openai' then - body = { model = llm.model, input = content } + body = { model = llm.model, input = input_string } elseif llm.type == 'ollama' then - body = { model = llm.model, prompt = content } + body = { model = llm.model, prompt = input_string } else rspamd_logger.debugm(N, task, 'unsupported llm type: %s', llm.type) return nil end - -- Redis cache: use content hash + model + provider as key + -- Redis cache: hash the final input string only (IUF is trivial here) local cache_ctx = lua_cache.create_cache_context(neural_common.redis_params, { cache_prefix = llm.cache_prefix, cache_ttl = llm.cache_ttl, @@ -97,9 +103,8 @@ neural_common.register_provider('llm', { cache_use_hashing = llm.cache_use_hashing, }, N) - -- Use a stable key based on content digest local hasher = require 'rspamd_cryptobox_hash' - local key = string.format('%s:%s:%s', llm.type, llm.model or 'model', hasher.create(content):hex()) + local key = string.format('%s:%s:%s', llm.type, llm.model or 'model', hasher.create(input_string):hex()) local function do_request_and_cache() local headers = { ['Content-Type'] = 'application/json' } @@ -119,161 +124,39 @@ neural_common.register_provider('llm', { use_gzip = true, } - local err, data = rspamd_http.request(http_params) - if err then - rspamd_logger.debugm(N, task, 'llm request failed: %s', err) - return nil - end - - local parser = ucl.parser() - local ok, perr = parser:parse_string(data.content) - if not ok then - rspamd_logger.debugm(N, task, 'cannot parse llm response: %s', perr) - return nil - end - - local parsed = parser:get_object() - local embedding = extract_embedding(llm.type, parsed) - if not embedding or #embedding == 0 then - rspamd_logger.debugm(N, task, 'no embedding in llm response') - return nil - end - - for i = 1, #embedding do - embedding[i] = tonumber(embedding[i]) or 0.0 - end - - lua_cache.cache_set(task, key, { e = embedding }, cache_ctx) - return embedding - end + local function http_cb(err, code, resp, _) + if err then + rspamd_logger.debugm(N, task, 'llm http error: %s', err) + return + end + if code ~= 200 or not resp then + rspamd_logger.debugm(N, task, 'llm bad http code: %s', code) + return + end - -- Try cache first - local cached_result - local done = false - lua_cache.cache_get(task, key, cache_ctx, llm.timeout or 2.0, - function(_) - -- Uncached: perform request synchronously and store - cached_result = do_request_and_cache() - done = true - end, - function(_, err, data) - if data and data.e then - cached_result = data.e + local parser = ucl.parser() + local ok, perr = parser:parse_string(resp) + if not ok then + rspamd_logger.debugm(N, task, 'llm cannot parse reply: %s', perr) + return + end + local parsed = parser:get_object() + local emb = extract_embedding(llm.type, parsed) + if type(emb) == 'table' then + cache_ctx:set_cached(key, emb) + neural_common.append_provider_vector(ctx, { provider = 'llm', vector = emb }) end - done = true end - ) - if not done then - -- Fallback: ensure we still do the request now (cache API is async-ready, but we need sync path) - cached_result = do_request_and_cache() + rspamd_http.request(http_params, http_cb) end - local embedding = cached_result - if not embedding then - return nil + local cached = cache_ctx:get_cached(key) + if type(cached) == 'table' then + neural_common.append_provider_vector(ctx, { provider = 'llm', vector = cached }) + return end - local meta = { - name = pcfg.name or 'llm', - type = 'llm', - dim = #embedding, - weight = pcfg.weight or 1.0, - model = llm.model, - provider = llm.type, - } - - return embedding, meta + do_request_and_cache() end, - collect_async = function(task, ctx, cont) - local pcfg = ctx.config or {} - local llm = compose_llm_settings(pcfg) - if not llm.model then - return cont(nil) - end - local content = select_text(task) - if not content or #content == 0 then - return cont(nil) - end - local body - if llm.type == 'openai' then - body = { model = llm.model, input = content } - elseif llm.type == 'ollama' then - body = { model = llm.model, prompt = content } - else - return cont(nil) - end - local cache_ctx = lua_cache.create_cache_context(neural_common.redis_params, { - cache_prefix = llm.cache_prefix, - cache_ttl = llm.cache_ttl, - cache_format = 'messagepack', - cache_hash_len = llm.cache_hash_len, - cache_use_hashing = llm.cache_use_hashing, - }, N) - local hasher = require 'rspamd_cryptobox_hash' - local key = string.format('%s:%s:%s', llm.type, llm.model or 'model', hasher.create(content):hex()) - - local function finish_with_embedding(embedding) - if not embedding then return cont(nil) end - for i = 1, #embedding do - embedding[i] = tonumber(embedding[i]) or 0.0 - end - cont(embedding, { - name = pcfg.name or 'llm', - type = 'llm', - dim = #embedding, - weight = pcfg.weight or 1.0, - model = llm.model, - provider = llm.type, - }) - end - - local function request_and_cache() - local headers = { ['Content-Type'] = 'application/json' } - if llm.type == 'openai' and llm.api_key then - headers['Authorization'] = 'Bearer ' .. llm.api_key - end - local http_params = { - url = llm.url, - mime_type = 'application/json', - timeout = llm.timeout, - log_obj = task, - headers = headers, - body = ucl.to_format(body, 'json-compact', true), - task = task, - method = 'POST', - use_gzip = true, - callback = function(err, _, data) - if err then return cont(nil) end - local parser = ucl.parser() - local ok = parser:parse_text(data) - if not ok then return cont(nil) end - local parsed = parser:get_object() - local embedding = extract_embedding(llm.type, parsed) - if embedding and cache_ctx then - lua_cache.cache_set(task, key, { e = embedding }, cache_ctx) - end - finish_with_embedding(embedding) - end, - } - rspamd_http.request(http_params) - end - - if cache_ctx then - lua_cache.cache_get(task, key, cache_ctx, llm.timeout or 2.0, - function(_) - request_and_cache() - end, - function(_, err, data) - if data and data.e then - finish_with_embedding(data.e) - else - request_and_cache() - end - end - ) - else - request_and_cache() - end - end }) diff --git a/rules/controller/neural.lua b/rules/controller/neural.lua index 0a55eebeeb..6e8cd80b58 100644 --- a/rules/controller/neural.lua +++ b/rules/controller/neural.lua @@ -18,8 +18,12 @@ local neural_common = require "plugins/neural" local ts = require("tableshape").types local ucl = require "ucl" local lua_util = require "lua_util" +local rspamd_util = require "rspamd_util" +local lua_redis = require "lua_redis" +local rspamd_logger = require "rspamd_logger" local E = {} +local N = 'neural' -- Controller neural plugin @@ -30,6 +34,7 @@ local learn_request_schema = ts.shape { } local function handle_learn(task, conn) + lua_util.debugm(N, task, 'controller.neural: learn called') local parser = ucl.parser() local ok, err = parser:parse_text(task:get_rawbody()) if not ok then @@ -59,10 +64,12 @@ local function handle_learn(task, conn) worker = task:get_worker(), } + lua_util.debugm(N, task, 'controller.neural: learn scheduled for rule=%s', rule_name) conn:send_string('{"success" : true}') end local function handle_status(task, conn, req_params) + lua_util.debugm(N, task, 'controller.neural: status called') local out = { rules = {}, } @@ -72,7 +79,20 @@ local function handle_status(task, conn, req_params) fusion = rule.fusion, max_inputs = rule.max_inputs, settings = {}, + requires_scan = false, } + -- Default: if no providers configured, assume symbols (full scan required) + local has_providers = type(rule.providers) == 'table' and #rule.providers > 0 + if not has_providers then + r.requires_scan = true + else + for _, p in ipairs(rule.providers) do + if p.type == 'symbols' then + r.requires_scan = true + break + end + end + end for sid, set in pairs(rule.settings or {}) do if type(set) == 'table' then local s = { @@ -95,6 +115,144 @@ local function handle_status(task, conn, req_params) conn:send_ucl({ success = true, data = out }) end +-- Return compact configuration for clients (e.g. rspamc) to plan learning +local function handle_config(task, conn, req_params) + lua_util.debugm(N, task, 'controller.neural: config called') + local out = { + rules = {}, + } + + for name, rule in pairs(neural_common.settings.rules) do + local requires_scan = false + local has_providers = type(rule.providers) == 'table' and #rule.providers > 0 + if not has_providers then + requires_scan = true + else + for _, p in ipairs(rule.providers) do + if p.type == 'symbols' then + requires_scan = true + break + end + end + end + + local r = { + requires_scan = requires_scan, + providers = {}, + recommended_path = requires_scan and '/checkv2' or '/controller/neural/learn_message', + settings = {}, + } + + if has_providers then + for _, p in ipairs(rule.providers) do + r.providers[#r.providers + 1] = { type = p.type } + end + end + + for sid, set in pairs(rule.settings or {}) do + if type(set) == 'table' then + r.settings[#r.settings + 1] = set.name + end + end + + out.rules[name] = r + end + + conn:send_ucl({ success = true, data = out }) +end + +-- Train directly from a message for providers that don't require full /checkv2 +-- Headers: +-- - ANN-Train or Class: 'spam' | 'ham' +-- - Rule: rule name (optional, default 'default') +local function handle_learn_message(task, conn) + lua_util.debugm(N, task, 'controller.neural: learn_message called') + local cls = task:get_request_header('ANN-Train') or task:get_request_header('Class') + if not cls then + conn:send_error(400, 'missing class header (ANN-Train or Class)') + return + end + + local learn_type = tostring(cls):lower() + if learn_type ~= 'spam' and learn_type ~= 'ham' then + conn:send_error(400, 'unsupported class (expected spam or ham)') + return + end + + local rule_name = task:get_request_header('Rule') or 'default' + local rule = neural_common.settings.rules[rule_name] + if not rule then + conn:send_error(400, 'unknown rule') + return + end + + -- If no providers or symbols provider configured, require full scan path + local has_providers = type(rule.providers) == 'table' and #rule.providers > 0 + if not has_providers then + lua_util.debugm(N, task, 'controller.neural: learn_message refused: no providers (assume symbols) for rule=%s', + rule_name) + conn:send_error(400, 'rule requires full /checkv2 scan (no providers configured)') + return + end + for _, p in ipairs(rule.providers) do + if p.type == 'symbols' then + lua_util.debugm(N, task, 'controller.neural: learn_message refused due to symbols provider for rule=%s', rule_name) + conn:send_error(400, 'rule requires full /checkv2 scan (symbols provider present)') + return + end + end + + local set = neural_common.get_rule_settings(task, rule) + if not set or not set.ann or not set.ann.redis_key then + conn:send_error(400, 'invalid rule settings for learning') + return + end + + local function after_collect(vec) + lua_util.debugm(N, task, 'controller.neural: learn_message after_collect, vector=%s', type(vec)) + if not vec then + vec = neural_common.result_to_vector(task, set) + end + + if type(vec) ~= 'table' then + conn:send_error(500, 'failed to build training vector') + return + end + + local compressed = rspamd_util.zstd_compress(table.concat(vec, ';')) + local target_key = string.format('%s_%s_set', set.ann.redis_key, learn_type) + + local function learn_vec_cb(redis_err) + if redis_err then + rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s', + rule.prefix, set.name, redis_err) + conn:send_error(500, 'cannot store train vector') + else + lua_util.debugm(N, task, 'controller.neural: stored train vector for rule=%s key=%s bytes=%s', rule_name, + target_key, #compressed) + conn:send_ucl({ success = true, stored = #compressed, key = target_key }) + end + end + + lua_redis.redis_make_request(task, + rule.redis, + nil, + true, + learn_vec_cb, + 'SADD', + { target_key, compressed } + ) + end + + if rule.providers and #rule.providers > 0 then + lua_util.debugm(N, task, 'controller.neural: collecting features for rule=%s', rule_name) + neural_common.collect_features_async(task, rule, set, 'train', after_collect) + else + -- Should not reach here due to early return + conn:send_error(400, 'rule requires full /checkv2 scan (no providers configured)') + end +end + local function handle_train(task, conn, req_params) local rule_name = req_params.rule or 'default' local rule = neural_common.settings.rules[rule_name] @@ -115,6 +273,16 @@ return { enable = true, need_task = true, }, + config = { + handler = handle_config, + enable = true, + need_task = false, + }, + learn_message = { + handler = handle_learn_message, + enable = true, + need_task = true, + }, status = { handler = handle_status, enable = false, diff --git a/src/client/rspamc.cxx b/src/client/rspamc.cxx index 1dc48faaec..e2128f357a 100644 --- a/src/client/rspamc.cxx +++ b/src/client/rspamc.cxx @@ -32,6 +32,7 @@ #include #include #include +#include #include "frozen/string.h" #include "frozen/unordered_map.h" @@ -59,7 +60,7 @@ static const char *user = nullptr; static const char *helo = nullptr; static const char *hostname = nullptr; static const char *classifier = nullptr; -static const char *learn_class_name = nullptr; +static std::string learn_class_name; static const char *local_addr = nullptr; static const char *execute = nullptr; static const char *sort = nullptr; @@ -94,6 +95,10 @@ static const char *files_list = nullptr; static const char *queue_id = nullptr; static const char *log_tag = nullptr; static std::string settings; +static std::string neural_train; +static std::string neural_rule; +static bool neural_cfg_loaded = false; +static std::unordered_map neural_rule_requires_scan; std::vector children; static GPatternSpec **exclude_compiled = nullptr; @@ -214,6 +219,7 @@ enum rspamc_command_type { RSPAMC_COMMAND_LEARN_SPAM, RSPAMC_COMMAND_LEARN_HAM, RSPAMC_COMMAND_LEARN_CLASS, + RSPAMC_COMMAND_NEURAL_LEARN, RSPAMC_COMMAND_FUZZY_ADD, RSPAMC_COMMAND_FUZZY_DEL, RSPAMC_COMMAND_FUZZY_DELHASH, @@ -241,7 +247,7 @@ static const constexpr auto rspamc_commands = rspamd::array_of( rspamc_command{ .cmd = RSPAMC_COMMAND_SYMBOLS, .name = "symbols", - .path = "checkv2", + .path = "/checkv2", .description = "scan message and show symbols (default command)", .is_controller = FALSE, .is_privileged = FALSE, @@ -250,7 +256,7 @@ static const constexpr auto rspamc_commands = rspamd::array_of( rspamc_command{ .cmd = RSPAMC_COMMAND_LEARN_SPAM, .name = "learn_spam", - .path = "learnspam", + .path = "/learnspam", .description = "learn message as spam", .is_controller = TRUE, .is_privileged = TRUE, @@ -259,7 +265,7 @@ static const constexpr auto rspamc_commands = rspamd::array_of( rspamc_command{ .cmd = RSPAMC_COMMAND_LEARN_HAM, .name = "learn_ham", - .path = "learnham", + .path = "/learnham", .description = "learn message as ham", .is_controller = TRUE, .is_privileged = TRUE, @@ -268,16 +274,25 @@ static const constexpr auto rspamc_commands = rspamd::array_of( rspamc_command{ .cmd = RSPAMC_COMMAND_LEARN_CLASS, .name = "learn_class", - .path = "learnclass", + .path = "/learnclass", .description = "learn message as class", .is_controller = TRUE, .is_privileged = TRUE, .need_input = TRUE, .command_output_func = nullptr}, + rspamc_command{ + .cmd = RSPAMC_COMMAND_NEURAL_LEARN, + .name = "neural_learn", + .path = "/checkv2", + .description = "learn neural with a class (use neural_learn:)", + .is_controller = FALSE, + .is_privileged = FALSE, + .need_input = TRUE, + .command_output_func = rspamc_symbols_output}, rspamc_command{ .cmd = RSPAMC_COMMAND_FUZZY_ADD, .name = "fuzzy_add", - .path = "fuzzyadd", + .path = "/fuzzyadd", .description = "add hashes from a message to the fuzzy storage (check -f and -w options for this command)", .is_controller = TRUE, @@ -287,7 +302,7 @@ static const constexpr auto rspamc_commands = rspamd::array_of( rspamc_command{ .cmd = RSPAMC_COMMAND_FUZZY_DEL, .name = "fuzzy_del", - .path = "fuzzydel", + .path = "/fuzzydel", .description = "delete hashes from a message from the fuzzy storage (check -f option for this command)", .is_controller = TRUE, @@ -297,7 +312,7 @@ static const constexpr auto rspamc_commands = rspamd::array_of( rspamc_command{ .cmd = RSPAMC_COMMAND_FUZZY_DELHASH, .name = "fuzzy_delhash", - .path = "fuzzydelhash", + .path = "/fuzzydelhash", .description = "delete a hash from fuzzy storage (check -f option for this command)", .is_controller = TRUE, @@ -307,7 +322,7 @@ static const constexpr auto rspamc_commands = rspamd::array_of( rspamc_command{ .cmd = RSPAMC_COMMAND_STAT, .name = "stat", - .path = "stat", + .path = "/stat", .description = "show rspamd statistics", .is_controller = TRUE, .is_privileged = FALSE, @@ -317,7 +332,7 @@ static const constexpr auto rspamc_commands = rspamd::array_of( rspamc_command{ .cmd = RSPAMC_COMMAND_STAT_RESET, .name = "stat_reset", - .path = "statreset", + .path = "/statreset", .description = "show and reset rspamd statistics (useful for graphs)", .is_controller = TRUE, .is_privileged = TRUE, @@ -326,7 +341,7 @@ static const constexpr auto rspamc_commands = rspamd::array_of( rspamc_command{ .cmd = RSPAMC_COMMAND_COUNTERS, .name = "counters", - .path = "counters", + .path = "/counters", .description = "display rspamd symbols statistics", .is_controller = TRUE, .is_privileged = FALSE, @@ -335,7 +350,7 @@ static const constexpr auto rspamc_commands = rspamd::array_of( rspamc_command{ .cmd = RSPAMC_COMMAND_UPTIME, .name = "uptime", - .path = "auth", + .path = "/auth", .description = "show rspamd uptime", .is_controller = TRUE, .is_privileged = FALSE, @@ -344,7 +359,7 @@ static const constexpr auto rspamc_commands = rspamd::array_of( rspamc_command{ .cmd = RSPAMC_COMMAND_ADD_SYMBOL, .name = "add_symbol", - .path = "addsymbol", + .path = "/addsymbol", .description = "add or modify symbol settings in rspamd", .is_controller = TRUE, .is_privileged = TRUE, @@ -353,7 +368,7 @@ static const constexpr auto rspamc_commands = rspamd::array_of( rspamc_command{ .cmd = RSPAMC_COMMAND_ADD_ACTION, .name = "add_action", - .path = "addaction", + .path = "/addaction", .description = "add or modify action settings", .is_controller = TRUE, .is_privileged = TRUE, @@ -714,6 +729,7 @@ check_rspamc_command(const char *cmd) -> std::optional {"learn_spam", RSPAMC_COMMAND_LEARN_SPAM}, {"learn_ham", RSPAMC_COMMAND_LEARN_HAM}, {"learn_class", RSPAMC_COMMAND_LEARN_CLASS}, + {"neural_learn", RSPAMC_COMMAND_NEURAL_LEARN}, {"fuzzy_add", RSPAMC_COMMAND_FUZZY_ADD}, {"fuzzy_del", RSPAMC_COMMAND_FUZZY_DEL}, {"fuzzy_delhash", RSPAMC_COMMAND_FUZZY_DELHASH}, @@ -725,22 +741,34 @@ check_rspamc_command(const char *cmd) -> std::optional std::string cmd_lc = rspamd_string_tolower(cmd); - // Handle learn_class:classname syntax - if (cmd_lc.find("learn_class:") == 0) { + /* Handle colon-suffixed commands in a unified way */ + { auto colon_pos = cmd_lc.find(':'); - if (colon_pos != std::string::npos && colon_pos + 1 < cmd_lc.length()) { - auto class_name = cmd_lc.substr(colon_pos + 1); - // Store class name globally for later use - learn_class_name = g_strdup(class_name.c_str()); - // Return the learn_class command - auto elt_it = std::find_if(rspamc_commands.begin(), rspamc_commands.end(), [&](const auto &item) { - return item.cmd == RSPAMC_COMMAND_LEARN_CLASS; - }); - if (elt_it != std::end(rspamc_commands)) { - return *elt_it; + if (colon_pos != std::string::npos) { + auto base = cmd_lc.substr(0, colon_pos); + auto arg = cmd_lc.substr(colon_pos + 1); + + if (!arg.empty()) { + auto find_cmd = [](enum rspamc_command_type t) -> std::optional { + const auto it = std::find_if(rspamc_commands.begin(), rspamc_commands.end(), [&t](const auto &item) { + return item.cmd == t; + }); + if (it != std::end(rspamc_commands)) { + return *it; + } + return std::nullopt; + }; + + if (base == "learn_class") { + learn_class_name = arg; + return find_cmd(RSPAMC_COMMAND_LEARN_CLASS); + } + else if (base == "neural_learn") { + neural_train = arg; /* allow any class name, plugin validates */ + return find_cmd(RSPAMC_COMMAND_NEURAL_LEARN); + } } } - return std::nullopt; } auto ct = rspamd::find_map(str_map, std::string_view{cmd_lc}); @@ -887,8 +915,12 @@ add_options(GQueue *opts) add_client_header(opts, "Classifier", classifier); } - if (learn_class_name) { - add_client_header(opts, "Class", learn_class_name); + if (!learn_class_name.empty()) { + add_client_header(opts, "Class", learn_class_name.c_str()); + } + + if (!neural_train.empty()) { + add_client_header(opts, "ANN-Train", neural_train.c_str()); } if (weight != 0) { @@ -939,6 +971,15 @@ add_options(GQueue *opts) add_client_header(opts, hdr_view.substr(0, std::distance(std::begin(hdr_view), delim_pos)), hdr_view.substr(std::distance(std::begin(hdr_view), delim_pos) + 1)); + /* Capture Rule header for neural selection */ + if (neural_rule.empty()) { + std::string name_copy{hdr_view.substr(0, std::distance(std::begin(hdr_view), delim_pos))}; + std::transform(name_copy.begin(), name_copy.end(), name_copy.begin(), [](unsigned char c) { return std::tolower(c); }); + if (name_copy == "rule") { + auto value_view = hdr_view.substr(std::distance(std::begin(hdr_view), delim_pos) + 1); + neural_rule.assign(value_view.begin(), value_view.end()); + } + } } hdr++; @@ -2049,6 +2090,7 @@ rspamc_process_input(struct ev_loop *ev_base, const struct rspamc_command &cmd, uint16_t port; GError *err = nullptr; std::string hostbuf; + std::string path_override; if (connect_str[0] == '[') { p = strrchr(connect_str, ']'); @@ -2096,6 +2138,22 @@ rspamc_process_input(struct ev_loop *ev_base, const struct rspamc_command &cmd, } } + /* Dynamic path/port override for neural_learn based on fetched config */ + if (cmd.cmd == RSPAMC_COMMAND_NEURAL_LEARN && neural_cfg_loaded) { + const auto &rule = !neural_rule.empty() ? neural_rule : std::string{"default"}; + auto it = neural_rule_requires_scan.find(rule); + bool requires_scan = true; + if (it != neural_rule_requires_scan.end()) { + requires_scan = it->second; + } + if (!requires_scan) { + path_override = "/plugins/neural/learn_message"; + if (p == nullptr) { + port = DEFAULT_CONTROL_PORT; + } + } + } + conn = rspamd_client_init(http_ctx, ev_base, hostbuf.c_str(), port, timeout, pubkey); if (conn != nullptr) { @@ -2104,12 +2162,14 @@ rspamc_process_input(struct ev_loop *ev_base, const struct rspamc_command &cmd, cbdata->filename = name; if (cmd.need_input) { - rspamd_client_command(conn, cmd.path, attrs, in, rspamc_client_cb, + const char *path = path_override.empty() ? cmd.path : path_override.c_str(); + rspamd_client_command(conn, path, attrs, in, rspamc_client_cb, cbdata, compressed, dictionary, cbdata->filename.c_str(), &err); } else { + const char *path = path_override.empty() ? cmd.path : path_override.c_str(); rspamd_client_command(conn, - cmd.path, + path, attrs, nullptr, rspamc_client_cb, @@ -2281,6 +2341,102 @@ rspamc_kwattr_free(gpointer p) g_free(h); } +/* Fetch /controller/neural/config once per run and populate neural_cfg_loaded + * and neural_rule_requires_scan map. */ +static void rspamc_neural_config_cb(struct rspamd_client_connection *conn, + struct rspamd_http_message *msg, + const char *name, ucl_object_t *result, GString *input, + gpointer ud, double start_time, double send_time, + const char *body, gsize bodylen, + GError *err) +{ + /* Populate map: data.rules[rule].requires_scan */ + if (result != nullptr) { + const auto *data = ucl_object_lookup(result, "data"); + if (data && ucl_object_type(data) == UCL_OBJECT) { + const auto *rules = ucl_object_lookup(data, "rules"); + if (rules && ucl_object_type(rules) == UCL_OBJECT) { + ucl_object_iter_t it = nullptr; + const ucl_object_t *cur; + while ((cur = ucl_object_iterate(rules, &it, true)) != nullptr) { + std::string rule_name = ucl_object_key(cur) ? ucl_object_key(cur) : ""; + auto requires_scan = true; + const auto *rq = ucl_object_lookup(cur, "requires_scan"); + if (rq) { + requires_scan = ucl_object_toboolean(rq); + } + neural_rule_requires_scan[rule_name] = requires_scan; + } + neural_cfg_loaded = true; + } + } + ucl_object_unref(result); + } + else if (err) { + /* Do not fail the whole run if config not available */ + neural_cfg_loaded = false; + } + + rspamd_client_destroy(conn); +} + +static void +rspamc_fetch_neural_config(struct ev_loop *ev_base, GQueue *attrs) +{ + /* Build connection to controller port */ + const char *p; + uint16_t port; + std::string hostbuf; + GError *err = nullptr; + + if (connect_str[0] == '[') { + p = strrchr(connect_str, ']'); + if (p != nullptr) { + hostbuf.assign(connect_str + 1, (std::size_t) (p - connect_str - 1)); + p++; + } + else { + p = connect_str; + } + } + else { + p = connect_str; + } + + p = strrchr(p, ':'); + if (hostbuf.empty()) { + if (p != nullptr) { + hostbuf.assign(connect_str, (std::size_t) (p - connect_str)); + } + else { + hostbuf.assign(connect_str); + } + } + + if (p != nullptr) { + port = strtoul(p + 1, nullptr, 10); + } + else { + /* Default to controller port if not specified */ + port = DEFAULT_CONTROL_PORT; + } + + auto *conn = rspamd_client_init(http_ctx, ev_base, hostbuf.c_str(), port, timeout, pubkey); + if (conn != nullptr) { + /* Minimal headers; reuse attrs so users can pass Password, etc. */ + rspamd_client_command(conn, + "/plugins/neural/config", + attrs, + nullptr, + rspamc_neural_config_cb, + nullptr, + compressed, + dictionary, + "neural_config", + &err); + } +} + int main(int argc, char **argv, char **env) { auto *kwattrs = g_queue_new(); @@ -2398,6 +2554,13 @@ int main(int argc, char **argv, char **env) add_options(kwattrs); auto cmd = maybe_cmd.value(); + /* Preload neural config once if we are going to use neural_learn */ + if (cmd.cmd == RSPAMC_COMMAND_NEURAL_LEARN && !neural_cfg_loaded) { + rspamc_fetch_neural_config(event_loop, kwattrs); + /* Drive loop once to complete config request */ + ev_loop(event_loop, 0); + } + if (start_argc == argc && files_list == nullptr) { /* Do command without input or with stdin */ if (empty_input) { diff --git a/src/client/rspamdclient.c b/src/client/rspamdclient.c index 24240d3c25..adffcf43f0 100644 --- a/src/client/rspamdclient.c +++ b/src/client/rspamdclient.c @@ -461,8 +461,14 @@ rspamd_client_command(struct rspamd_client_connection *conn, */ rspamd_http_message_add_header(req->msg, "Accept", "application/msgpack"); - req->msg->url = rspamd_fstring_append(req->msg->url, "/", 1); - req->msg->url = rspamd_fstring_append(req->msg->url, command, strlen(command)); + /* Append path ensuring a single leading slash */ + if (command != NULL && command[0] == '/') { + req->msg->url = rspamd_fstring_append(req->msg->url, command, strlen(command)); + } + else { + req->msg->url = rspamd_fstring_append(req->msg->url, "/", 1); + req->msg->url = rspamd_fstring_append(req->msg->url, command ? command : "", command ? strlen(command) : 0); + } conn->req = req; conn->start_time = rspamd_get_ticks(FALSE); diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua index 633a45854a..1e8a135f18 100644 --- a/src/plugins/lua/neural.lua +++ b/src/plugins/lua/neural.lua @@ -69,20 +69,20 @@ local function new_ann_profile(task, rule, set, version) local function add_cb(err, _) if err then rspamd_logger.errx(task, 'cannot store ANN profile for %s:%s at %s : %s', - rule.prefix, set.name, profile.redis_key, err) + rule.prefix, set.name, profile.redis_key, err) else rspamd_logger.infox(task, 'created new ANN profile for %s:%s, data stored at prefix %s', - rule.prefix, set.name, profile.redis_key) + rule.prefix, set.name, profile.redis_key) end end lua_redis.redis_make_request(task, - rule.redis, - nil, - true, -- is write - add_cb, --callback - 'ZADD', -- command - { set.prefix, tostring(rspamd_util.get_time()), profile_serialized } + rule.redis, + nil, + true, -- is write + add_cb, --callback + 'ZADD', -- command + { set.prefix, tostring(rspamd_util.get_time()), profile_serialized } ) return profile @@ -103,18 +103,18 @@ local function ann_scores_filter(task) profile = set.ann else lua_util.debugm(N, task, 'no ann loaded for %s:%s', - rule.prefix, set.name) + rule.prefix, set.name) end else lua_util.debugm(N, task, 'no ann defined in %s for settings id %s', - rule.prefix, sid) + rule.prefix, sid) end if ann then local function after_features(vec, meta) if profile.providers_digest and meta and meta.digest and profile.providers_digest ~= meta.digest then lua_util.debugm(N, task, 'providers digest mismatch for %s:%s, skip ANN apply', - rule.prefix, set.name) + rule.prefix, set.name) vec = nil end @@ -131,7 +131,7 @@ local function ann_scores_filter(task) local symscore = string.format('%.3f', score) task:cache_set(rule.prefix .. '_neural_score', score) lua_util.debugm(N, task, '%s:%s:%s ann score: %s', - rule.prefix, set.name, set.ann.version, symscore) + rule.prefix, set.name, set.ann.version, symscore) if score > 0 then local result = score @@ -152,8 +152,8 @@ local function ann_scores_filter(task) end else lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (spam threshold)', - rule.prefix, set.name, set.ann.version, symscore, - spam_threshold) + rule.prefix, set.name, set.ann.version, symscore, + spam_threshold) end else local result = -(score) @@ -174,8 +174,8 @@ local function ann_scores_filter(task) end else lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (ham threshold)', - rule.prefix, set.name, set.ann.version, result, - ham_threshold) + rule.prefix, set.name, set.ann.version, result, + ham_threshold) end end end @@ -194,21 +194,40 @@ local function ann_push_task_result(rule, task, verdict, score, set) local train_opts = rule.train local learn_spam, learn_ham local skip_reason = 'unknown' + local manual_train = false + + -- First, honor explicit manual training header if present + do + local hdr = task:get_request_header('ANN-Train') + if hdr then + local hv = tostring(hdr):lower() + lua_util.debugm(N, task, 'found ANN-Train header, enable manual train mode', hv) + if hv == 'spam' then + learn_spam = true + manual_train = true + elseif hv == 'ham' then + learn_ham = true + manual_train = true + else + skip_reason = 'no explicit header' + end + end + end - if not train_opts.store_pool_only and train_opts.autotrain then + if not manual_train and (not train_opts.store_pool_only and train_opts.autotrain) then if train_opts.spam_score then learn_spam = score >= train_opts.spam_score if not learn_spam then skip_reason = string.format('score < spam_score: %f < %f', - score, train_opts.spam_score) + score, train_opts.spam_score) end else learn_spam = verdict == 'spam' or verdict == 'junk' if not learn_spam then skip_reason = string.format('verdict: %s', - verdict) + verdict) end end @@ -216,29 +235,18 @@ local function ann_push_task_result(rule, task, verdict, score, set) learn_ham = score <= train_opts.ham_score if not learn_ham then skip_reason = string.format('score > ham_score: %f > %f', - score, train_opts.ham_score) + score, train_opts.ham_score) end else learn_ham = verdict == 'ham' if not learn_ham then skip_reason = string.format('verdict: %s', - verdict) + verdict) end end else - -- Train by request header - local hdr = task:get_request_header('ANN-Train') - - if hdr then - if hdr:lower() == 'spam' then - learn_spam = true - elseif hdr:lower() == 'ham' then - learn_ham = true - else - skip_reason = 'no explicit header' - end - elseif train_opts.store_pool_only then + if train_opts.store_pool_only and not manual_train then local ucl = require "ucl" learn_ham = false learn_spam = false @@ -272,7 +280,7 @@ local function ann_push_task_result(rule, task, verdict, score, set) if not err and type(data) == 'table' then local nspam, nham = data[1], data[2] - if neural_common.can_push_train_vector(rule, task, learn_type, nspam, nham) then + if manual_train or neural_common.can_push_train_vector(rule, task, learn_type, nspam, nham) then local vec if rule.providers and #rule.providers > 0 then -- Note: this training path remains sync for now; vectors are pushed when computed @@ -288,41 +296,41 @@ local function ann_push_task_result(rule, task, verdict, score, set) local function learn_vec_cb(redis_err) if redis_err then rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s', - rule.prefix, set.name, redis_err) + rule.prefix, set.name, redis_err) else lua_util.debugm(N, task, - "add train data for ANN rule " .. - "%s:%s, save %s vector of %s elts in %s key; %s bytes compressed", - rule.prefix, set.name, learn_type, #vec, target_key, #str) + "add train data for ANN rule " .. + "%s:%s, save %s vector of %s elts in %s key; %s bytes compressed", + rule.prefix, set.name, learn_type, #vec, target_key, #str) end end lua_redis.redis_make_request(task, - rule.redis, - nil, - true, -- is write - learn_vec_cb, --callback - 'SADD', -- command - { target_key, str } -- arguments + rule.redis, + nil, + true, -- is write + learn_vec_cb, --callback + 'SADD', -- command + { target_key, str } -- arguments ) else lua_util.debugm(N, task, - "do not add %s train data for ANN rule " .. - "%s:%s", - learn_type, rule.prefix, set.name) + "do not add %s train data for ANN rule " .. + "%s:%s", + learn_type, rule.prefix, set.name) end else if err then rspamd_logger.errx(task, 'cannot check if we can train %s:%s : %s', - rule.prefix, set.name, err) + rule.prefix, set.name, err) elseif type(data) == 'string' then -- nil return value rspamd_logger.infox(task, "cannot learn %s ANN %s:%s; redis_key: %s: locked for learning: %s", - learn_type, rule.prefix, set.name, set.ann.redis_key, data) + learn_type, rule.prefix, set.name, set.ann.redis_key, data) else rspamd_logger.errx(task, 'cannot check if we can train %s:%s : type of Redis key %s is %s, expected table' .. - 'please remove this key from Redis manually if you perform upgrade from the previous version', - rule.prefix, set.name, set.ann.redis_key, type(data)) + 'please remove this key from Redis manually if you perform upgrade from the previous version', + rule.prefix, set.name, set.ann.redis_key, type(data)) end end end @@ -333,25 +341,25 @@ local function ann_push_task_result(rule, task, verdict, score, set) -- Need to create or load a profile corresponding to the current configuration set.ann = new_ann_profile(task, rule, set, 0) lua_util.debugm(N, task, - 'requested new profile for %s, set.ann is missing', - set.name) + 'requested new profile for %s, set.ann is missing', + set.name) end lua_redis.exec_redis_script(neural_common.redis_script_id.vectors_len, - { task = task, is_write = false }, - vectors_len_cb, - { - set.ann.redis_key, - }) + { task = task, is_write = false }, + vectors_len_cb, + { + set.ann.redis_key, + }) else lua_util.debugm(N, task, - 'do not push data: train condition not satisfied; reason: not checked existing ANNs') + 'do not push data: train condition not satisfied; reason: not checked existing ANNs') end else lua_util.debugm(N, task, - 'do not push data to key %s: train condition not satisfied; reason: %s', - (set.ann or {}).redis_key, - skip_reason) + 'do not push data to key %s: train condition not satisfied; reason: %s', + (set.ann or {}).redis_key, + skip_reason) end end @@ -376,16 +384,16 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key) local function redis_ham_cb(err, data) if err or type(data) ~= 'table' then rspamd_logger.errx(rspamd_config, 'cannot get ham tokens for ANN %s from redis: %s', - ann_key, err) + ann_key, err) -- Unlock on error lua_redis.redis_make_request_taskless(ev_base, - rspamd_config, - rule.redis, - nil, - true, -- is write - neural_common.gen_unlock_cb(rule, set, ann_key), --callback - 'HDEL', -- command - { ann_key, 'lock' } + rspamd_config, + rule.redis, + nil, + true, -- is write + neural_common.gen_unlock_cb(rule, set, ann_key), --callback + 'HDEL', -- command + { ann_key, 'lock' } ) else -- Decompress and convert to numbers each training vector @@ -406,29 +414,29 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key) local function redis_spam_cb(err, data) if err or type(data) ~= 'table' then rspamd_logger.errx(rspamd_config, 'cannot get spam tokens for ANN %s from redis: %s', - ann_key, err) + ann_key, err) -- Unlock ANN on error lua_redis.redis_make_request_taskless(ev_base, - rspamd_config, - rule.redis, - nil, - true, -- is write - neural_common.gen_unlock_cb(rule, set, ann_key), --callback - 'HDEL', -- command - { ann_key, 'lock' } + rspamd_config, + rule.redis, + nil, + true, -- is write + neural_common.gen_unlock_cb(rule, set, ann_key), --callback + 'HDEL', -- command + { ann_key, 'lock' } ) else -- Decompress and convert to numbers each training vector spam_elts = process_training_vectors(data) -- Now get ham vectors... lua_redis.redis_make_request_taskless(ev_base, - rspamd_config, - rule.redis, - nil, - false, -- is write - redis_ham_cb, --callback - 'SMEMBERS', -- command - { ann_key .. '_ham_set' } + rspamd_config, + rule.redis, + nil, + false, -- is write + redis_ham_cb, --callback + 'SMEMBERS', -- command + { ann_key .. '_ham_set' } ) end end @@ -436,33 +444,33 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key) local function redis_lock_cb(err, data) if err then rspamd_logger.errx(rspamd_config, 'cannot call lock script for ANN %s from redis: %s', - ann_key, err) + ann_key, err) elseif type(data) == 'number' and data == 1 then -- ANN is locked, so we can extract SPAM and HAM vectors and spawn learning lua_redis.redis_make_request_taskless(ev_base, - rspamd_config, - rule.redis, - nil, - false, -- is write - redis_spam_cb, --callback - 'SMEMBERS', -- command - { ann_key .. '_spam_set' } + rspamd_config, + rule.redis, + nil, + false, -- is write + redis_spam_cb, --callback + 'SMEMBERS', -- command + { ann_key .. '_spam_set' } ) rspamd_logger.infox(rspamd_config, 'lock ANN %s:%s (key name %s) for learning', - rule.prefix, set.name, ann_key) + rule.prefix, set.name, ann_key) else local lock_tm = tonumber(data[1]) rspamd_logger.infox(rspamd_config, 'do not learn ANN %s:%s (key name %s), ' .. - 'locked by another host %s at %s', rule.prefix, set.name, ann_key, - data[2], os.date('%c', lock_tm)) + 'locked by another host %s at %s', rule.prefix, set.name, ann_key, + data[2], os.date('%c', lock_tm)) end end -- Check if we are already learning this network if set.learning_spawned then rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, already learning another ANN', - ann_key) + ann_key) return end @@ -470,14 +478,14 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key) -- This script returns either a boolean or a pair {'lock_time', 'hostname'} when -- ANN is locked by another host (or a process, meh) lua_redis.exec_redis_script(neural_common.redis_script_id.maybe_lock, - { ev_base = ev_base, is_write = true }, - redis_lock_cb, - { - ann_key, - tostring(os.time()), - tostring(math.max(10.0, rule.watch_interval * 2)), - rspamd_util.get_hostname() - }) + { ev_base = ev_base, is_write = true }, + redis_lock_cb, + { + ann_key, + tostring(os.time()), + tostring(math.max(10.0, rule.watch_interval * 2)), + rspamd_util.get_hostname() + }) end -- This function loads new ann from Redis @@ -492,7 +500,7 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff) local function data_cb(err, data) if err then rspamd_logger.errx(rspamd_config, 'cannot get ANN data from key: %s; %s', - ann_key, err) + ann_key, err) else if type(data) == 'table' then if type(data[1]) == 'userdata' and data[1].cookie == text_cookie then @@ -501,7 +509,7 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff) if _err or not ann_data then rspamd_logger.errx(rspamd_config, 'cannot decompress ANN for %s from Redis key %s: %s', - rule.prefix .. ':' .. set.name, ann_key, _err) + rule.prefix .. ':' .. set.name, ann_key, _err) return else ann = rspamd_kann.load(ann_data) @@ -525,26 +533,26 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff) end -- Also update rank for the loaded ANN to avoid removal lua_redis.redis_make_request_taskless(ev_base, - rspamd_config, - rule.redis, - nil, - true, -- is write - rank_cb, --callback - 'ZADD', -- command - { set.prefix, tostring(rspamd_util.get_time()), profile_serialized } + rspamd_config, + rule.redis, + nil, + true, -- is write + rank_cb, --callback + 'ZADD', -- command + { set.prefix, tostring(rspamd_util.get_time()), profile_serialized } ) rspamd_logger.infox(rspamd_config, - 'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s', - rule.prefix, set.name, ann_key, #data[1], profile.version) + 'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s', + rule.prefix, set.name, ann_key, #data[1], profile.version) else rspamd_logger.errx(rspamd_config, - 'cannot unpack/deserialise ANN for %s:%s from Redis key %s', - rule.prefix, set.name, ann_key) + 'cannot unpack/deserialise ANN for %s:%s from Redis key %s', + rule.prefix, set.name, ann_key) end end else lua_util.debugm(N, rspamd_config, 'missing ANN for %s:%s in Redis key %s', - rule.prefix, set.name, ann_key) + rule.prefix, set.name, ann_key) end if set.ann and set.ann.ann and type(data[2]) == 'userdata' and data[2].cookie == text_cookie then @@ -556,8 +564,8 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff) local roc_thresholds = parser:get_object() set.ann.roc_thresholds = roc_thresholds rspamd_logger.infox(rspamd_config, - 'loaded ROC thresholds for %s:%s; version=%s', - rule.prefix, set.name, profile.version) + 'loaded ROC thresholds for %s:%s; version=%s', + rule.prefix, set.name, profile.version) rspamd_logger.debugx("ROC thresholds: %s", roc_thresholds) end end @@ -570,19 +578,19 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff) -- We can use PCA set.ann.pca = rspamd_tensor.load(pca_data) rspamd_logger.infox(rspamd_config, - 'loaded PCA for ANN for %s:%s from %s; %s bytes compressed; version=%s', - rule.prefix, set.name, ann_key, #data[3], profile.version) + 'loaded PCA for ANN for %s:%s from %s; %s bytes compressed; version=%s', + rule.prefix, set.name, ann_key, #data[3], profile.version) else -- no need in pca, why is it there? rspamd_logger.warnx(rspamd_config, - 'extra PCA for ANN for %s:%s from Redis key %s: no max inputs defined', - rule.prefix, set.name, ann_key) + 'extra PCA for ANN for %s:%s from Redis key %s: no max inputs defined', + rule.prefix, set.name, ann_key) end else -- pca can be missing merely if we have no max_inputs if rule.max_inputs then rspamd_logger.errx(rspamd_config, 'cannot unpack/deserialise ANN for %s:%s from Redis key %s: no PCA: %s', - rule.prefix, set.name, ann_key, _err) + rule.prefix, set.name, ann_key, _err) set.ann.ann = nil else -- It is okay @@ -611,19 +619,19 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff) end else lua_util.debugm(N, rspamd_config, 'no ANN key for %s:%s in Redis key %s', - rule.prefix, set.name, ann_key) + rule.prefix, set.name, ann_key) end end end lua_redis.redis_make_request_taskless(ev_base, - rspamd_config, - rule.redis, - nil, - false, -- is write - data_cb, --callback - 'HMGET', -- command - { ann_key, 'ann', 'roc_thresholds', 'pca', 'providers_meta', 'norm_stats' }, -- arguments - { opaque_data = true } + rspamd_config, + rule.redis, + nil, + false, -- is write + data_cb, --callback + 'HMGET', -- command + { ann_key, 'ann', 'roc_thresholds', 'pca', 'providers_meta', 'norm_stats' }, -- arguments + { opaque_data = true } ) end @@ -659,34 +667,34 @@ local function process_existing_ann(_, ev_base, rule, set, profiles) if set.ann.version < sel_elt.version then -- Load new ann rspamd_logger.infox(rspamd_config, 'ann %s is changed, ' .. - 'our version = %s, remote version = %s', - rule.prefix .. ':' .. set.name, - set.ann.version, - sel_elt.version) + 'our version = %s, remote version = %s', + rule.prefix .. ':' .. set.name, + set.ann.version, + sel_elt.version) load_new_ann(rule, ev_base, set, sel_elt, min_diff) else lua_util.debugm(N, rspamd_config, 'ann %s is not changed, ' .. - 'our version = %s, remote version = %s', - rule.prefix .. ':' .. set.name, - set.ann.version, - sel_elt.version) + 'our version = %s, remote version = %s', + rule.prefix .. ':' .. set.name, + set.ann.version, + sel_elt.version) end else -- We have some different ANN, so we need to compare distance if set.ann.distance > min_diff then -- Load more specific ANN rspamd_logger.infox(rspamd_config, 'more specific ann is available for %s, ' .. - 'our distance = %s, remote distance = %s', - rule.prefix .. ':' .. set.name, - set.ann.distance, - min_diff) + 'our distance = %s, remote distance = %s', + rule.prefix .. ':' .. set.name, + set.ann.distance, + min_diff) load_new_ann(rule, ev_base, set, sel_elt, min_diff) else lua_util.debugm(N, rspamd_config, 'ann %s is not changed or less specific, ' .. - 'our distance = %s, remote distance = %s', - rule.prefix .. ':' .. set.name, - set.ann.distance, - min_diff) + 'our distance = %s, remote distance = %s', + rule.prefix .. ':' .. set.name, + set.ann.distance, + min_diff) end end else @@ -724,14 +732,14 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles) local ann_key = sel_elt.redis_key lua_util.debugm(N, rspamd_config, "check if ANN %s needs to be trained", - ann_key) + ann_key) -- Create continuation closure local redis_len_cb_gen = function(cont_cb, what, is_final) return function(err, data) if err then rspamd_logger.errx(rspamd_config, - 'cannot get ANN %s trains %s from redis: %s', what, ann_key, err) + 'cannot get ANN %s trains %s from redis: %s', what, ann_key, err) elseif data and type(data) == 'number' or type(data) == 'string' then local ntrains = tonumber(data) or 0 lens[what] = ntrains @@ -752,31 +760,31 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles) end if max_len >= rule.train.max_trains and fun.all(len_bias_check_pred, lens) then lua_util.debugm(N, rspamd_config, - 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors', - ann_key, lens, rule.train.max_trains, what) + 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors', + ann_key, lens, rule.train.max_trains, what) cont_cb() else lua_util.debugm(N, rspamd_config, - 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)', - ann_key, what, lens, rule.train.max_trains) + 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)', + ann_key, what, lens, rule.train.max_trains) end else -- Probabilistic mode, just ensure that at least one vector is okay if min_len > 0 and max_len >= rule.train.max_trains then lua_util.debugm(N, rspamd_config, - 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors', - ann_key, lens, rule.train.max_trains, what) + 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors', + ann_key, lens, rule.train.max_trains, what) cont_cb() else lua_util.debugm(N, rspamd_config, - 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)', - ann_key, what, lens, rule.train.max_trains) + 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)', + ann_key, what, lens, rule.train.max_trains) end end else lua_util.debugm(N, rspamd_config, - 'checked %s vectors in ANN %s: %s vectors; %s required, need to check other class vectors', - what, ann_key, ntrains, rule.train.max_trains) + 'checked %s vectors in ANN %s: %s vectors; %s required, need to check other class vectors', + what, ann_key, ntrains, rule.train.max_trains) cont_cb() end end @@ -785,32 +793,32 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles) local function initiate_train() rspamd_logger.infox(rspamd_config, - 'need to learn ANN %s after %s required learn vectors', - ann_key, lens) + 'need to learn ANN %s after %s required learn vectors', + ann_key, lens) do_train_ann(worker, ev_base, rule, set, ann_key) end -- Spam vector is OK, check ham vector length local function check_ham_len() lua_redis.redis_make_request_taskless(ev_base, - rspamd_config, - rule.redis, - nil, - false, -- is write - redis_len_cb_gen(initiate_train, 'ham', true), --callback - 'SCARD', -- command - { ann_key .. '_ham_set' } + rspamd_config, + rule.redis, + nil, + false, -- is write + redis_len_cb_gen(initiate_train, 'ham', true), --callback + 'SCARD', -- command + { ann_key .. '_ham_set' } ) end lua_redis.redis_make_request_taskless(ev_base, - rspamd_config, - rule.redis, - nil, - false, -- is write - redis_len_cb_gen(check_ham_len, 'spam', false), --callback - 'SCARD', -- command - { ann_key .. '_spam_set' } + rspamd_config, + rule.redis, + nil, + false, -- is write + redis_len_cb_gen(check_ham_len, 'spam', false), --callback + 'SCARD', -- command + { ann_key .. '_spam_set' } ) end end @@ -823,7 +831,7 @@ local function load_ann_profile(element) local res, ucl_err = parser:parse_string(element) if not res then rspamd_logger.warnx(rspamd_config, 'cannot parse ANN from redis: %s', - ucl_err) + ucl_err) return nil else local profile = parser:get_object() @@ -843,11 +851,11 @@ local function check_anns(worker, cfg, ev_base, rule, process_callback, what) local function members_cb(err, data) if err then rspamd_logger.errx(cfg, 'cannot get ANNs list from redis: %s', - err) + err) set.can_store_vectors = true elseif type(data) == 'table' then lua_util.debugm(N, cfg, '%s: process element %s:%s', - what, rule.prefix, set.name) + what, rule.prefix, set.name) process_callback(worker, ev_base, rule, set, fun.map(load_ann_profile, data)) set.can_store_vectors = true end @@ -859,13 +867,13 @@ local function check_anns(worker, cfg, ev_base, rule, process_callback, what) -- Select the most appropriate to our profile but it should not differ by more -- than 30% of symbols lua_redis.redis_make_request_taskless(ev_base, - cfg, - rule.redis, - nil, - false, -- is write - members_cb, --callback - 'ZREVRANGE', -- command - { set.prefix, '0', tostring(settings.max_profiles) } -- arguments + cfg, + rule.redis, + nil, + false, -- is write + members_cb, --callback + 'ZREVRANGE', -- command + { set.prefix, '0', tostring(settings.max_profiles) } -- arguments ) end end -- Cycle over all settings @@ -879,23 +887,23 @@ local function cleanup_anns(rule, cfg, ev_base) local function invalidate_cb(err, data) if err then rspamd_logger.errx(cfg, 'cannot exec invalidate script in redis: %s', - err) + err) elseif type(data) == 'table' then for _, expired in ipairs(data) do local profile = load_ann_profile(expired) rspamd_logger.infox(cfg, 'invalidated ANN for %s; redis key: %s; version=%s', - rule.prefix .. ':' .. set.name, - profile.redis_key, - profile.version) + rule.prefix .. ':' .. set.name, + profile.redis_key, + profile.version) end end end if type(set) == 'table' then lua_redis.exec_redis_script(neural_common.redis_script_id.maybe_invalidate, - { ev_base = ev_base, is_write = true }, - invalidate_cb, - { set.prefix, tostring(settings.max_profiles) }) + { ev_base = ev_base, is_write = true }, + invalidate_cb, + { set.prefix, tostring(settings.max_profiles) }) end end end @@ -914,14 +922,14 @@ local function ann_push_vector(task) if verdict == 'passthrough' then lua_util.debugm(N, task, 'ignore task as its verdict is %s(%s)', - verdict, score) + verdict, score) return end if score ~= score then lua_util.debugm(N, task, 'ignore task as its score is nan (%s verdict)', - verdict) + verdict) return end @@ -991,7 +999,7 @@ for k, r in pairs(rules) do if rule_elt.max_inputs and not has_blas then rspamd_logger.errx('cannot set max inputs to %s as BLAS is not compiled in', - rule_elt.name, rule_elt.max_inputs) + rule_elt.name, rule_elt.max_inputs) rule_elt.max_inputs = nil end @@ -1003,7 +1011,7 @@ for k, r in pairs(rules) do end if (pcfg.type == 'llm' or pcfg.name == 'llm') and not (pcfg.model or (rspamd_config:get_all_opt('gpt') or {}).model) then rspamd_logger.errx(rspamd_config, - 'llm provider in rule %s requires model; please set providers[i].model or gpt.model', k) + 'llm provider in rule %s requires model; please set providers[i].model or gpt.model', k) end end end @@ -1054,21 +1062,21 @@ for _, rule in pairs(settings.rules) do rspamd_config:add_on_load(function(cfg, ev_base, worker) if worker:is_scanner() then rspamd_config:add_periodic(ev_base, 0.0, - function(_, _) - return check_anns(worker, cfg, ev_base, rule, process_existing_ann, - 'try_load_ann') - end) + function(_, _) + return check_anns(worker, cfg, ev_base, rule, process_existing_ann, + 'try_load_ann') + end) end if worker:is_primary_controller() then -- We also want to train neural nets when they have enough data rspamd_config:add_periodic(ev_base, 0.0, - function(_, _) - -- Clean old ANNs - cleanup_anns(rule, cfg, ev_base) - return check_anns(worker, cfg, ev_base, rule, maybe_train_existing_ann, - 'try_train_ann') - end) + function(_, _) + -- Clean old ANNs + cleanup_anns(rule, cfg, ev_base) + return check_anns(worker, cfg, ev_base, rule, maybe_train_existing_ann, + 'try_train_ann') + end) end end) end diff --git a/test/functional/lib/rspamd.robot b/test/functional/lib/rspamd.robot index f61998f46e..1bda5e5c2e 100644 --- a/test/functional/lib/rspamd.robot +++ b/test/functional/lib/rspamd.robot @@ -475,6 +475,15 @@ Run Dummy Https Wait Until Created /tmp/dummy_https.pid timeout=2 second Export Scoped Variables ${RSPAMD_SCOPE} DUMMY_HTTPS_PROC=${result} +Run Dummy Llm + ${result} = Start Process python3 ${RSPAMD_TESTDIR}/util/dummy_llm.py 18080 + Wait Until Created /tmp/dummy_llm.pid timeout=2 second + Export Scoped Variables ${RSPAMD_SCOPE} DUMMY_LLM_PROC=${result} + +Dummy Llm Teardown + Terminate Process ${DUMMY_LLM_PROC} + Wait For Process ${DUMMY_LLM_PROC} + Dummy Http Teardown Terminate Process ${DUMMY_HTTP_PROC} Wait For Process ${DUMMY_HTTP_PROC}