local N = "neural.llm"
local function select_text(task)
- local content = llm_common.build_llm_input(task)
- return content
+ local input_tbl = llm_common.build_llm_input(task)
+ return input_tbl
end
local function compose_llm_settings(pcfg)
return nil
end
- local content = select_text(task)
- if not content or #content == 0 then
+ local input_tbl = select_text(task)
+ if not input_tbl then
rspamd_logger.debugm(N, task, 'llm provider has no content to embed; skip')
return nil
end
+ -- Build request input string (text then optional subject), keeping rspamd_text intact
+ local input_string = input_tbl.text or ''
+ if input_tbl.subject and input_tbl.subject ~= '' then
+ input_string = input_string .. "\nSubject: " .. input_tbl.subject
+ end
+
local body
if llm.type == 'openai' then
- body = { model = llm.model, input = content }
+ body = { model = llm.model, input = input_string }
elseif llm.type == 'ollama' then
- body = { model = llm.model, prompt = content }
+ body = { model = llm.model, prompt = input_string }
else
rspamd_logger.debugm(N, task, 'unsupported llm type: %s', llm.type)
return nil
end
- -- Redis cache: use content hash + model + provider as key
+ -- Redis cache: hash the final input string only (single init/update/final hash pass is trivial here)
local cache_ctx = lua_cache.create_cache_context(neural_common.redis_params, {
cache_prefix = llm.cache_prefix,
cache_ttl = llm.cache_ttl,
cache_use_hashing = llm.cache_use_hashing,
}, N)
- -- Use a stable key based on content digest
local hasher = require 'rspamd_cryptobox_hash'
- local key = string.format('%s:%s:%s', llm.type, llm.model or 'model', hasher.create(content):hex())
+ local key = string.format('%s:%s:%s', llm.type, llm.model or 'model', hasher.create(input_string):hex())
local function do_request_and_cache()
local headers = { ['Content-Type'] = 'application/json' }
use_gzip = true,
}
- local err, data = rspamd_http.request(http_params)
- if err then
- rspamd_logger.debugm(N, task, 'llm request failed: %s', err)
- return nil
- end
-
- local parser = ucl.parser()
- local ok, perr = parser:parse_string(data.content)
- if not ok then
- rspamd_logger.debugm(N, task, 'cannot parse llm response: %s', perr)
- return nil
- end
-
- local parsed = parser:get_object()
- local embedding = extract_embedding(llm.type, parsed)
- if not embedding or #embedding == 0 then
- rspamd_logger.debugm(N, task, 'no embedding in llm response')
- return nil
- end
-
- for i = 1, #embedding do
- embedding[i] = tonumber(embedding[i]) or 0.0
- end
-
- lua_cache.cache_set(task, key, { e = embedding }, cache_ctx)
- return embedding
- end
+ local function http_cb(err, code, resp, _)
+ if err then
+ rspamd_logger.debugm(N, task, 'llm http error: %s', err)
+ return
+ end
+ if code ~= 200 or not resp then
+ rspamd_logger.debugm(N, task, 'llm bad http code: %s', code)
+ return
+ end
- -- Try cache first
- local cached_result
- local done = false
- lua_cache.cache_get(task, key, cache_ctx, llm.timeout or 2.0,
- function(_)
- -- Uncached: perform request synchronously and store
- cached_result = do_request_and_cache()
- done = true
- end,
- function(_, err, data)
- if data and data.e then
- cached_result = data.e
+ local parser = ucl.parser()
+ local ok, perr = parser:parse_string(resp)
+ if not ok then
+ rspamd_logger.debugm(N, task, 'llm cannot parse reply: %s', perr)
+ return
+ end
+ local parsed = parser:get_object()
+ local emb = extract_embedding(llm.type, parsed)
+ if type(emb) == 'table' then
+ cache_ctx:set_cached(key, emb)
+ neural_common.append_provider_vector(ctx, { provider = 'llm', vector = emb })
end
- done = true
end
- )
- if not done then
- -- Fallback: ensure we still do the request now (cache API is async-ready, but we need sync path)
- cached_result = do_request_and_cache()
+ rspamd_http.request(http_params, http_cb)
end
- local embedding = cached_result
- if not embedding then
- return nil
+ local cached = cache_ctx:get_cached(key)
+ if type(cached) == 'table' then
+ neural_common.append_provider_vector(ctx, { provider = 'llm', vector = cached })
+ return
end
- local meta = {
- name = pcfg.name or 'llm',
- type = 'llm',
- dim = #embedding,
- weight = pcfg.weight or 1.0,
- model = llm.model,
- provider = llm.type,
- }
-
- return embedding, meta
+ do_request_and_cache()
end,
- collect_async = function(task, ctx, cont)
- local pcfg = ctx.config or {}
- local llm = compose_llm_settings(pcfg)
- if not llm.model then
- return cont(nil)
- end
- local content = select_text(task)
- if not content or #content == 0 then
- return cont(nil)
- end
- local body
- if llm.type == 'openai' then
- body = { model = llm.model, input = content }
- elseif llm.type == 'ollama' then
- body = { model = llm.model, prompt = content }
- else
- return cont(nil)
- end
- local cache_ctx = lua_cache.create_cache_context(neural_common.redis_params, {
- cache_prefix = llm.cache_prefix,
- cache_ttl = llm.cache_ttl,
- cache_format = 'messagepack',
- cache_hash_len = llm.cache_hash_len,
- cache_use_hashing = llm.cache_use_hashing,
- }, N)
- local hasher = require 'rspamd_cryptobox_hash'
- local key = string.format('%s:%s:%s', llm.type, llm.model or 'model', hasher.create(content):hex())
-
- local function finish_with_embedding(embedding)
- if not embedding then return cont(nil) end
- for i = 1, #embedding do
- embedding[i] = tonumber(embedding[i]) or 0.0
- end
- cont(embedding, {
- name = pcfg.name or 'llm',
- type = 'llm',
- dim = #embedding,
- weight = pcfg.weight or 1.0,
- model = llm.model,
- provider = llm.type,
- })
- end
-
- local function request_and_cache()
- local headers = { ['Content-Type'] = 'application/json' }
- if llm.type == 'openai' and llm.api_key then
- headers['Authorization'] = 'Bearer ' .. llm.api_key
- end
- local http_params = {
- url = llm.url,
- mime_type = 'application/json',
- timeout = llm.timeout,
- log_obj = task,
- headers = headers,
- body = ucl.to_format(body, 'json-compact', true),
- task = task,
- method = 'POST',
- use_gzip = true,
- callback = function(err, _, data)
- if err then return cont(nil) end
- local parser = ucl.parser()
- local ok = parser:parse_text(data)
- if not ok then return cont(nil) end
- local parsed = parser:get_object()
- local embedding = extract_embedding(llm.type, parsed)
- if embedding and cache_ctx then
- lua_cache.cache_set(task, key, { e = embedding }, cache_ctx)
- end
- finish_with_embedding(embedding)
- end,
- }
- rspamd_http.request(http_params)
- end
-
- if cache_ctx then
- lua_cache.cache_get(task, key, cache_ctx, llm.timeout or 2.0,
- function(_)
- request_and_cache()
- end,
- function(_, err, data)
- if data and data.e then
- finish_with_embedding(data.e)
- else
- request_and_cache()
- end
- end
- )
- else
- request_and_cache()
- end
- end
})
local ts = require("tableshape").types
local ucl = require "ucl"
local lua_util = require "lua_util"
+local rspamd_util = require "rspamd_util"
+local lua_redis = require "lua_redis"
+local rspamd_logger = require "rspamd_logger"
local E = {}
+local N = 'neural'
-- Controller neural plugin
}
local function handle_learn(task, conn)
+ lua_util.debugm(N, task, 'controller.neural: learn called')
local parser = ucl.parser()
local ok, err = parser:parse_text(task:get_rawbody())
if not ok then
worker = task:get_worker(),
}
+ lua_util.debugm(N, task, 'controller.neural: learn scheduled for rule=%s', rule_name)
conn:send_string('{"success" : true}')
end
local function handle_status(task, conn, req_params)
+ lua_util.debugm(N, task, 'controller.neural: status called')
local out = {
rules = {},
}
fusion = rule.fusion,
max_inputs = rule.max_inputs,
settings = {},
+ requires_scan = false,
}
+ -- Default: if no providers configured, assume symbols (full scan required)
+ local has_providers = type(rule.providers) == 'table' and #rule.providers > 0
+ if not has_providers then
+ r.requires_scan = true
+ else
+ for _, p in ipairs(rule.providers) do
+ if p.type == 'symbols' then
+ r.requires_scan = true
+ break
+ end
+ end
+ end
for sid, set in pairs(rule.settings or {}) do
if type(set) == 'table' then
local s = {
conn:send_ucl({ success = true, data = out })
end
+-- Return compact configuration for clients (e.g. rspamc) to plan learning.
+-- Reply shape: { success = true, data = { rules = { [rule_name] = {
+--   requires_scan = bool,        -- true when a full /checkv2 scan is needed
+--   providers = { {type=...} },  -- provider types only (no secrets/config)
+--   recommended_path = string,   -- endpoint the client should use for learning
+--   settings = { name, ... },    -- names of per-settings-id profiles
+-- } } } }
+local function handle_config(task, conn, req_params)
+ lua_util.debugm(N, task, 'controller.neural: config called')
+ local out = {
+ rules = {},
+ }
+
+ for name, rule in pairs(neural_common.settings.rules) do
+ -- A rule requires a full scan when it has no providers configured
+ -- (legacy symbols-only behaviour) or when any provider is 'symbols',
+ -- since symbol vectors only exist after a complete message scan.
+ local requires_scan = false
+ local has_providers = type(rule.providers) == 'table' and #rule.providers > 0
+ if not has_providers then
+ requires_scan = true
+ else
+ for _, p in ipairs(rule.providers) do
+ if p.type == 'symbols' then
+ requires_scan = true
+ break
+ end
+ end
+ end
+
+ -- NOTE(review): recommended_path here is '/controller/neural/learn_message'
+ -- while the rspamc client elsewhere targets '/plugins/neural/learn_message';
+ -- confirm which path the controller actually exposes for this plugin.
+ local r = {
+ requires_scan = requires_scan,
+ providers = {},
+ recommended_path = requires_scan and '/checkv2' or '/controller/neural/learn_message',
+ settings = {},
+ }
+
+ -- Expose provider types only; full provider configuration stays private
+ if has_providers then
+ for _, p in ipairs(rule.providers) do
+ r.providers[#r.providers + 1] = { type = p.type }
+ end
+ end
+
+ -- Collect settings profile names (sid key itself is not reported)
+ for sid, set in pairs(rule.settings or {}) do
+ if type(set) == 'table' then
+ r.settings[#r.settings + 1] = set.name
+ end
+ end
+
+ out.rules[name] = r
+ end
+
+ conn:send_ucl({ success = true, data = out })
+end
+
+-- Train directly from a message for providers that don't require full /checkv2.
+-- Builds a training vector via the rule's providers (or result_to_vector as a
+-- fallback), zstd-compresses it and SADDs it to the per-class Redis train set.
+-- Headers:
+-- - ANN-Train or Class: 'spam' | 'ham'
+-- - Rule: rule name (optional, default 'default')
+-- Replies with JSON {success=true, stored=<bytes>, key=<redis key>} or an HTTP error.
+local function handle_learn_message(task, conn)
+ lua_util.debugm(N, task, 'controller.neural: learn_message called')
+ local cls = task:get_request_header('ANN-Train') or task:get_request_header('Class')
+ if not cls then
+ conn:send_error(400, 'missing class header (ANN-Train or Class)')
+ return
+ end
+
+ -- tostring() also covers rspamd_text userdata returned for headers
+ local learn_type = tostring(cls):lower()
+ if learn_type ~= 'spam' and learn_type ~= 'ham' then
+ conn:send_error(400, 'unsupported class (expected spam or ham)')
+ return
+ end
+
+ -- NOTE(review): get_request_header may return an rspamd_text userdata;
+ -- using it directly as a table key relies on string coercion — TODO confirm
+ -- (a tostring() here would be safer if headers are not plain strings).
+ local rule_name = task:get_request_header('Rule') or 'default'
+ local rule = neural_common.settings.rules[rule_name]
+ if not rule then
+ conn:send_error(400, 'unknown rule')
+ return
+ end
+
+ -- If no providers or symbols provider configured, require full scan path
+ local has_providers = type(rule.providers) == 'table' and #rule.providers > 0
+ if not has_providers then
+ lua_util.debugm(N, task, 'controller.neural: learn_message refused: no providers (assume symbols) for rule=%s',
+ rule_name)
+ conn:send_error(400, 'rule requires full /checkv2 scan (no providers configured)')
+ return
+ end
+ for _, p in ipairs(rule.providers) do
+ if p.type == 'symbols' then
+ lua_util.debugm(N, task, 'controller.neural: learn_message refused due to symbols provider for rule=%s', rule_name)
+ conn:send_error(400, 'rule requires full /checkv2 scan (symbols provider present)')
+ return
+ end
+ end
+
+ -- Resolve the settings-id profile; learning needs an existing ANN redis key
+ local set = neural_common.get_rule_settings(task, rule)
+ if not set or not set.ann or not set.ann.redis_key then
+ conn:send_error(400, 'invalid rule settings for learning')
+ return
+ end
+
+ -- Invoked once provider features are collected; vec may be nil on failure
+ local function after_collect(vec)
+ lua_util.debugm(N, task, 'controller.neural: learn_message after_collect, vector=%s', type(vec))
+ if not vec then
+ -- Fall back to building the vector from task results
+ vec = neural_common.result_to_vector(task, set)
+ end
+
+ if type(vec) ~= 'table' then
+ conn:send_error(500, 'failed to build training vector')
+ return
+ end
+
+ -- Same on-wire format as the autotrain path: ';'-joined, zstd-compressed
+ local compressed = rspamd_util.zstd_compress(table.concat(vec, ';'))
+ local target_key = string.format('%s_%s_set', set.ann.redis_key, learn_type)
+
+ -- Redis SADD callback: report storage result back over the HTTP connection
+ local function learn_vec_cb(redis_err)
+ if redis_err then
+ rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s',
+ rule.prefix, set.name, redis_err)
+ conn:send_error(500, 'cannot store train vector')
+ else
+ lua_util.debugm(N, task, 'controller.neural: stored train vector for rule=%s key=%s bytes=%s', rule_name,
+ target_key, #compressed)
+ conn:send_ucl({ success = true, stored = #compressed, key = target_key })
+ end
+ end
+
+ lua_redis.redis_make_request(task,
+ rule.redis,
+ nil,
+ true,
+ learn_vec_cb,
+ 'SADD',
+ { target_key, compressed }
+ )
+ end
+
+ if rule.providers and #rule.providers > 0 then
+ lua_util.debugm(N, task, 'controller.neural: collecting features for rule=%s', rule_name)
+ neural_common.collect_features_async(task, rule, set, 'train', after_collect)
+ else
+ -- Should not reach here due to early return
+ conn:send_error(400, 'rule requires full /checkv2 scan (no providers configured)')
+ end
+end
+
local function handle_train(task, conn, req_params)
local rule_name = req_params.rule or 'default'
local rule = neural_common.settings.rules[rule_name]
enable = true,
need_task = true,
},
+ config = {
+ handler = handle_config,
+ enable = true,
+ need_task = false,
+ },
+ learn_message = {
+ handler = handle_learn_message,
+ enable = true,
+ need_task = true,
+ },
status = {
handler = handle_status,
enable = false,
#include <cstdio>
#include <cmath>
#include <locale>
+#include <unordered_map>
#include "frozen/string.h"
#include "frozen/unordered_map.h"
static const char *helo = nullptr;
static const char *hostname = nullptr;
static const char *classifier = nullptr;
-static const char *learn_class_name = nullptr;
+static std::string learn_class_name;
static const char *local_addr = nullptr;
static const char *execute = nullptr;
static const char *sort = nullptr;
static const char *queue_id = nullptr;
static const char *log_tag = nullptr;
static std::string settings;
+static std::string neural_train;
+static std::string neural_rule;
+static bool neural_cfg_loaded = false;
+static std::unordered_map<std::string, bool> neural_rule_requires_scan;
std::vector<GPid> children;
static GPatternSpec **exclude_compiled = nullptr;
RSPAMC_COMMAND_LEARN_SPAM,
RSPAMC_COMMAND_LEARN_HAM,
RSPAMC_COMMAND_LEARN_CLASS,
+ RSPAMC_COMMAND_NEURAL_LEARN,
RSPAMC_COMMAND_FUZZY_ADD,
RSPAMC_COMMAND_FUZZY_DEL,
RSPAMC_COMMAND_FUZZY_DELHASH,
rspamc_command{
.cmd = RSPAMC_COMMAND_SYMBOLS,
.name = "symbols",
- .path = "checkv2",
+ .path = "/checkv2",
.description = "scan message and show symbols (default command)",
.is_controller = FALSE,
.is_privileged = FALSE,
rspamc_command{
.cmd = RSPAMC_COMMAND_LEARN_SPAM,
.name = "learn_spam",
- .path = "learnspam",
+ .path = "/learnspam",
.description = "learn message as spam",
.is_controller = TRUE,
.is_privileged = TRUE,
rspamc_command{
.cmd = RSPAMC_COMMAND_LEARN_HAM,
.name = "learn_ham",
- .path = "learnham",
+ .path = "/learnham",
.description = "learn message as ham",
.is_controller = TRUE,
.is_privileged = TRUE,
rspamc_command{
.cmd = RSPAMC_COMMAND_LEARN_CLASS,
.name = "learn_class",
- .path = "learnclass",
+ .path = "/learnclass",
.description = "learn message as class",
.is_controller = TRUE,
.is_privileged = TRUE,
.need_input = TRUE,
.command_output_func = nullptr},
+ rspamc_command{
+ .cmd = RSPAMC_COMMAND_NEURAL_LEARN,
+ .name = "neural_learn",
+ .path = "/checkv2",
+ .description = "learn neural with a class (use neural_learn:<class>)",
+ .is_controller = FALSE,
+ .is_privileged = FALSE,
+ .need_input = TRUE,
+ .command_output_func = rspamc_symbols_output},
rspamc_command{
.cmd = RSPAMC_COMMAND_FUZZY_ADD,
.name = "fuzzy_add",
- .path = "fuzzyadd",
+ .path = "/fuzzyadd",
.description =
"add hashes from a message to the fuzzy storage (check -f and -w options for this command)",
.is_controller = TRUE,
rspamc_command{
.cmd = RSPAMC_COMMAND_FUZZY_DEL,
.name = "fuzzy_del",
- .path = "fuzzydel",
+ .path = "/fuzzydel",
.description =
"delete hashes from a message from the fuzzy storage (check -f option for this command)",
.is_controller = TRUE,
rspamc_command{
.cmd = RSPAMC_COMMAND_FUZZY_DELHASH,
.name = "fuzzy_delhash",
- .path = "fuzzydelhash",
+ .path = "/fuzzydelhash",
.description =
"delete a hash from fuzzy storage (check -f option for this command)",
.is_controller = TRUE,
rspamc_command{
.cmd = RSPAMC_COMMAND_STAT,
.name = "stat",
- .path = "stat",
+ .path = "/stat",
.description = "show rspamd statistics",
.is_controller = TRUE,
.is_privileged = FALSE,
rspamc_command{
.cmd = RSPAMC_COMMAND_STAT_RESET,
.name = "stat_reset",
- .path = "statreset",
+ .path = "/statreset",
.description = "show and reset rspamd statistics (useful for graphs)",
.is_controller = TRUE,
.is_privileged = TRUE,
rspamc_command{
.cmd = RSPAMC_COMMAND_COUNTERS,
.name = "counters",
- .path = "counters",
+ .path = "/counters",
.description = "display rspamd symbols statistics",
.is_controller = TRUE,
.is_privileged = FALSE,
rspamc_command{
.cmd = RSPAMC_COMMAND_UPTIME,
.name = "uptime",
- .path = "auth",
+ .path = "/auth",
.description = "show rspamd uptime",
.is_controller = TRUE,
.is_privileged = FALSE,
rspamc_command{
.cmd = RSPAMC_COMMAND_ADD_SYMBOL,
.name = "add_symbol",
- .path = "addsymbol",
+ .path = "/addsymbol",
.description = "add or modify symbol settings in rspamd",
.is_controller = TRUE,
.is_privileged = TRUE,
rspamc_command{
.cmd = RSPAMC_COMMAND_ADD_ACTION,
.name = "add_action",
- .path = "addaction",
+ .path = "/addaction",
.description = "add or modify action settings",
.is_controller = TRUE,
.is_privileged = TRUE,
{"learn_spam", RSPAMC_COMMAND_LEARN_SPAM},
{"learn_ham", RSPAMC_COMMAND_LEARN_HAM},
{"learn_class", RSPAMC_COMMAND_LEARN_CLASS},
+ {"neural_learn", RSPAMC_COMMAND_NEURAL_LEARN},
{"fuzzy_add", RSPAMC_COMMAND_FUZZY_ADD},
{"fuzzy_del", RSPAMC_COMMAND_FUZZY_DEL},
{"fuzzy_delhash", RSPAMC_COMMAND_FUZZY_DELHASH},
std::string cmd_lc = rspamd_string_tolower(cmd);
- // Handle learn_class:classname syntax
- if (cmd_lc.find("learn_class:") == 0) {
+ /* Handle colon-suffixed commands in a unified way */
+ {
auto colon_pos = cmd_lc.find(':');
- if (colon_pos != std::string::npos && colon_pos + 1 < cmd_lc.length()) {
- auto class_name = cmd_lc.substr(colon_pos + 1);
- // Store class name globally for later use
- learn_class_name = g_strdup(class_name.c_str());
- // Return the learn_class command
- auto elt_it = std::find_if(rspamc_commands.begin(), rspamc_commands.end(), [&](const auto &item) {
- return item.cmd == RSPAMC_COMMAND_LEARN_CLASS;
- });
- if (elt_it != std::end(rspamc_commands)) {
- return *elt_it;
+ if (colon_pos != std::string::npos) {
+ auto base = cmd_lc.substr(0, colon_pos);
+ auto arg = cmd_lc.substr(colon_pos + 1);
+
+ if (!arg.empty()) {
+ auto find_cmd = [](enum rspamc_command_type t) -> std::optional<rspamc_command> {
+ const auto it = std::find_if(rspamc_commands.begin(), rspamc_commands.end(), [&t](const auto &item) {
+ return item.cmd == t;
+ });
+ if (it != std::end(rspamc_commands)) {
+ return *it;
+ }
+ return std::nullopt;
+ };
+
+ if (base == "learn_class") {
+ learn_class_name = arg;
+ return find_cmd(RSPAMC_COMMAND_LEARN_CLASS);
+ }
+ else if (base == "neural_learn") {
+ neural_train = arg; /* allow any class name, plugin validates */
+ return find_cmd(RSPAMC_COMMAND_NEURAL_LEARN);
+ }
}
}
- return std::nullopt;
}
auto ct = rspamd::find_map(str_map, std::string_view{cmd_lc});
add_client_header(opts, "Classifier", classifier);
}
- if (learn_class_name) {
- add_client_header(opts, "Class", learn_class_name);
+ if (!learn_class_name.empty()) {
+ add_client_header(opts, "Class", learn_class_name.c_str());
+ }
+
+ if (!neural_train.empty()) {
+ add_client_header(opts, "ANN-Train", neural_train.c_str());
}
if (weight != 0) {
add_client_header(opts,
hdr_view.substr(0, std::distance(std::begin(hdr_view), delim_pos)),
hdr_view.substr(std::distance(std::begin(hdr_view), delim_pos) + 1));
+ /* Capture Rule header for neural selection */
+ if (neural_rule.empty()) {
+ std::string name_copy{hdr_view.substr(0, std::distance(std::begin(hdr_view), delim_pos))};
+ std::transform(name_copy.begin(), name_copy.end(), name_copy.begin(), [](unsigned char c) { return std::tolower(c); });
+ if (name_copy == "rule") {
+ auto value_view = hdr_view.substr(std::distance(std::begin(hdr_view), delim_pos) + 1);
+ neural_rule.assign(value_view.begin(), value_view.end());
+ }
+ }
}
hdr++;
uint16_t port;
GError *err = nullptr;
std::string hostbuf;
+ std::string path_override;
if (connect_str[0] == '[') {
p = strrchr(connect_str, ']');
}
}
+ /* Dynamic path/port override for neural_learn based on fetched config */
+ if (cmd.cmd == RSPAMC_COMMAND_NEURAL_LEARN && neural_cfg_loaded) {
+ const auto &rule = !neural_rule.empty() ? neural_rule : std::string{"default"};
+ auto it = neural_rule_requires_scan.find(rule);
+ bool requires_scan = true;
+ if (it != neural_rule_requires_scan.end()) {
+ requires_scan = it->second;
+ }
+ if (!requires_scan) {
+ path_override = "/plugins/neural/learn_message";
+ if (p == nullptr) {
+ port = DEFAULT_CONTROL_PORT;
+ }
+ }
+ }
+
conn = rspamd_client_init(http_ctx, ev_base, hostbuf.c_str(), port, timeout, pubkey);
if (conn != nullptr) {
cbdata->filename = name;
if (cmd.need_input) {
- rspamd_client_command(conn, cmd.path, attrs, in, rspamc_client_cb,
+ const char *path = path_override.empty() ? cmd.path : path_override.c_str();
+ rspamd_client_command(conn, path, attrs, in, rspamc_client_cb,
cbdata, compressed, dictionary, cbdata->filename.c_str(), &err);
}
else {
+ const char *path = path_override.empty() ? cmd.path : path_override.c_str();
rspamd_client_command(conn,
- cmd.path,
+ path,
attrs,
nullptr,
rspamc_client_cb,
g_free(h);
}
+/* Fetch /controller/neural/config once per run and populate neural_cfg_loaded
+ * and neural_rule_requires_scan map.
+ * HTTP client completion callback: parses data.rules[rule].requires_scan from
+ * the UCL reply. Missing 'requires_scan' defaults to true (safe: forces a
+ * full /checkv2 scan). Always destroys the connection before returning.
+ * Most parameters are unused here; the signature matches the client callback
+ * contract. */
+static void rspamc_neural_config_cb(struct rspamd_client_connection *conn,
+ struct rspamd_http_message *msg,
+ const char *name, ucl_object_t *result, GString *input,
+ gpointer ud, double start_time, double send_time,
+ const char *body, gsize bodylen,
+ GError *err)
+{
+ /* Populate map: data.rules[rule].requires_scan */
+ if (result != nullptr) {
+ const auto *data = ucl_object_lookup(result, "data");
+ if (data && ucl_object_type(data) == UCL_OBJECT) {
+ const auto *rules = ucl_object_lookup(data, "rules");
+ if (rules && ucl_object_type(rules) == UCL_OBJECT) {
+ ucl_object_iter_t it = nullptr;
+ const ucl_object_t *cur;
+ while ((cur = ucl_object_iterate(rules, &it, true)) != nullptr) {
+ std::string rule_name = ucl_object_key(cur) ? ucl_object_key(cur) : "";
+ auto requires_scan = true;
+ const auto *rq = ucl_object_lookup(cur, "requires_scan");
+ if (rq) {
+ requires_scan = ucl_object_toboolean(rq);
+ }
+ neural_rule_requires_scan[rule_name] = requires_scan;
+ }
+ /* Only mark loaded when a rules object was actually present */
+ neural_cfg_loaded = true;
+ }
+ }
+ /* This callback owns the parsed result */
+ ucl_object_unref(result);
+ }
+ else if (err) {
+ /* Do not fail the whole run if config not available */
+ neural_cfg_loaded = false;
+ }
+
+ rspamd_client_destroy(conn);
+}
+
+/* Connect to the controller and request the neural plugin configuration.
+ * Parses host/port out of the global connect_str (supports '[v6]:port',
+ * 'host:port' and bare 'host'); defaults to DEFAULT_CONTROL_PORT when no
+ * port is given. The reply is handled asynchronously by
+ * rspamc_neural_config_cb once the caller drives the event loop.
+ * NOTE(review): 'err' is passed to rspamd_client_command but never checked
+ * or freed here — confirm the callee's ownership semantics.
+ * NOTE(review): requests '/plugins/neural/config' while the Lua side hints
+ * '/controller/neural/...' paths — confirm the exposed endpoint. */
+static void
+rspamc_fetch_neural_config(struct ev_loop *ev_base, GQueue *attrs)
+{
+ /* Build connection to controller port */
+ const char *p;
+ uint16_t port;
+ std::string hostbuf;
+ GError *err = nullptr;
+
+ /* Bracketed IPv6 literal: host is between '[' and ']' */
+ if (connect_str[0] == '[') {
+ p = strrchr(connect_str, ']');
+ if (p != nullptr) {
+ hostbuf.assign(connect_str + 1, (std::size_t) (p - connect_str - 1));
+ p++;
+ }
+ else {
+ p = connect_str;
+ }
+ }
+ else {
+ p = connect_str;
+ }
+
+ /* Optional ':port' suffix after the host part */
+ p = strrchr(p, ':');
+ if (hostbuf.empty()) {
+ if (p != nullptr) {
+ hostbuf.assign(connect_str, (std::size_t) (p - connect_str));
+ }
+ else {
+ hostbuf.assign(connect_str);
+ }
+ }
+
+ if (p != nullptr) {
+ port = strtoul(p + 1, nullptr, 10);
+ }
+ else {
+ /* Default to controller port if not specified */
+ port = DEFAULT_CONTROL_PORT;
+ }
+
+ auto *conn = rspamd_client_init(http_ctx, ev_base, hostbuf.c_str(), port, timeout, pubkey);
+ if (conn != nullptr) {
+ /* Minimal headers; reuse attrs so users can pass Password, etc. */
+ rspamd_client_command(conn,
+ "/plugins/neural/config",
+ attrs,
+ nullptr,
+ rspamc_neural_config_cb,
+ nullptr,
+ compressed,
+ dictionary,
+ "neural_config",
+ &err);
+ }
+}
+
int main(int argc, char **argv, char **env)
{
auto *kwattrs = g_queue_new();
add_options(kwattrs);
auto cmd = maybe_cmd.value();
+ /* Preload neural config once if we are going to use neural_learn */
+ if (cmd.cmd == RSPAMC_COMMAND_NEURAL_LEARN && !neural_cfg_loaded) {
+ rspamc_fetch_neural_config(event_loop, kwattrs);
+ /* Drive loop once to complete config request */
+ ev_loop(event_loop, 0);
+ }
+
if (start_argc == argc && files_list == nullptr) {
/* Do command without input or with stdin */
if (empty_input) {
local function add_cb(err, _)
if err then
rspamd_logger.errx(task, 'cannot store ANN profile for %s:%s at %s : %s',
- rule.prefix, set.name, profile.redis_key, err)
+ rule.prefix, set.name, profile.redis_key, err)
else
rspamd_logger.infox(task, 'created new ANN profile for %s:%s, data stored at prefix %s',
- rule.prefix, set.name, profile.redis_key)
+ rule.prefix, set.name, profile.redis_key)
end
end
lua_redis.redis_make_request(task,
- rule.redis,
- nil,
- true, -- is write
- add_cb, --callback
- 'ZADD', -- command
- { set.prefix, tostring(rspamd_util.get_time()), profile_serialized }
+ rule.redis,
+ nil,
+ true, -- is write
+ add_cb, --callback
+ 'ZADD', -- command
+ { set.prefix, tostring(rspamd_util.get_time()), profile_serialized }
)
return profile
profile = set.ann
else
lua_util.debugm(N, task, 'no ann loaded for %s:%s',
- rule.prefix, set.name)
+ rule.prefix, set.name)
end
else
lua_util.debugm(N, task, 'no ann defined in %s for settings id %s',
- rule.prefix, sid)
+ rule.prefix, sid)
end
if ann then
local function after_features(vec, meta)
if profile.providers_digest and meta and meta.digest and profile.providers_digest ~= meta.digest then
lua_util.debugm(N, task, 'providers digest mismatch for %s:%s, skip ANN apply',
- rule.prefix, set.name)
+ rule.prefix, set.name)
vec = nil
end
local symscore = string.format('%.3f', score)
task:cache_set(rule.prefix .. '_neural_score', score)
lua_util.debugm(N, task, '%s:%s:%s ann score: %s',
- rule.prefix, set.name, set.ann.version, symscore)
+ rule.prefix, set.name, set.ann.version, symscore)
if score > 0 then
local result = score
end
else
lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (spam threshold)',
- rule.prefix, set.name, set.ann.version, symscore,
- spam_threshold)
+ rule.prefix, set.name, set.ann.version, symscore,
+ spam_threshold)
end
else
local result = -(score)
end
else
lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (ham threshold)',
- rule.prefix, set.name, set.ann.version, result,
- ham_threshold)
+ rule.prefix, set.name, set.ann.version, result,
+ ham_threshold)
end
end
end
local train_opts = rule.train
local learn_spam, learn_ham
local skip_reason = 'unknown'
+ local manual_train = false
+
+ -- First, honor explicit manual training header if present
+ do
+ local hdr = task:get_request_header('ANN-Train')
+ if hdr then
+ local hv = tostring(hdr):lower()
+ lua_util.debugm(N, task, 'found ANN-Train header, enable manual train mode', hv)
+ if hv == 'spam' then
+ learn_spam = true
+ manual_train = true
+ elseif hv == 'ham' then
+ learn_ham = true
+ manual_train = true
+ else
+ skip_reason = 'no explicit header'
+ end
+ end
+ end
- if not train_opts.store_pool_only and train_opts.autotrain then
+ if not manual_train and (not train_opts.store_pool_only and train_opts.autotrain) then
if train_opts.spam_score then
learn_spam = score >= train_opts.spam_score
if not learn_spam then
skip_reason = string.format('score < spam_score: %f < %f',
- score, train_opts.spam_score)
+ score, train_opts.spam_score)
end
else
learn_spam = verdict == 'spam' or verdict == 'junk'
if not learn_spam then
skip_reason = string.format('verdict: %s',
- verdict)
+ verdict)
end
end
learn_ham = score <= train_opts.ham_score
if not learn_ham then
skip_reason = string.format('score > ham_score: %f > %f',
- score, train_opts.ham_score)
+ score, train_opts.ham_score)
end
else
learn_ham = verdict == 'ham'
if not learn_ham then
skip_reason = string.format('verdict: %s',
- verdict)
+ verdict)
end
end
else
- -- Train by request header
- local hdr = task:get_request_header('ANN-Train')
-
- if hdr then
- if hdr:lower() == 'spam' then
- learn_spam = true
- elseif hdr:lower() == 'ham' then
- learn_ham = true
- else
- skip_reason = 'no explicit header'
- end
- elseif train_opts.store_pool_only then
+ if train_opts.store_pool_only and not manual_train then
local ucl = require "ucl"
learn_ham = false
learn_spam = false
if not err and type(data) == 'table' then
local nspam, nham = data[1], data[2]
- if neural_common.can_push_train_vector(rule, task, learn_type, nspam, nham) then
+ if manual_train or neural_common.can_push_train_vector(rule, task, learn_type, nspam, nham) then
local vec
if rule.providers and #rule.providers > 0 then
-- Note: this training path remains sync for now; vectors are pushed when computed
local function learn_vec_cb(redis_err)
if redis_err then
rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s',
- rule.prefix, set.name, redis_err)
+ rule.prefix, set.name, redis_err)
else
lua_util.debugm(N, task,
- "add train data for ANN rule " ..
- "%s:%s, save %s vector of %s elts in %s key; %s bytes compressed",
- rule.prefix, set.name, learn_type, #vec, target_key, #str)
+ "add train data for ANN rule " ..
+ "%s:%s, save %s vector of %s elts in %s key; %s bytes compressed",
+ rule.prefix, set.name, learn_type, #vec, target_key, #str)
end
end
lua_redis.redis_make_request(task,
- rule.redis,
- nil,
- true, -- is write
- learn_vec_cb, --callback
- 'SADD', -- command
- { target_key, str } -- arguments
+ rule.redis,
+ nil,
+ true, -- is write
+ learn_vec_cb, --callback
+ 'SADD', -- command
+ { target_key, str } -- arguments
)
else
lua_util.debugm(N, task,
- "do not add %s train data for ANN rule " ..
- "%s:%s",
- learn_type, rule.prefix, set.name)
+ "do not add %s train data for ANN rule " ..
+ "%s:%s",
+ learn_type, rule.prefix, set.name)
end
else
if err then
rspamd_logger.errx(task, 'cannot check if we can train %s:%s : %s',
- rule.prefix, set.name, err)
+ rule.prefix, set.name, err)
elseif type(data) == 'string' then
-- nil return value
rspamd_logger.infox(task, "cannot learn %s ANN %s:%s; redis_key: %s: locked for learning: %s",
- learn_type, rule.prefix, set.name, set.ann.redis_key, data)
+ learn_type, rule.prefix, set.name, set.ann.redis_key, data)
else
rspamd_logger.errx(task, 'cannot check if we can train %s:%s : type of Redis key %s is %s, expected table' ..
- 'please remove this key from Redis manually if you perform upgrade from the previous version',
- rule.prefix, set.name, set.ann.redis_key, type(data))
+ 'please remove this key from Redis manually if you perform upgrade from the previous version',
+ rule.prefix, set.name, set.ann.redis_key, type(data))
end
end
end
-- Need to create or load a profile corresponding to the current configuration
set.ann = new_ann_profile(task, rule, set, 0)
lua_util.debugm(N, task,
- 'requested new profile for %s, set.ann is missing',
- set.name)
+ 'requested new profile for %s, set.ann is missing',
+ set.name)
end
lua_redis.exec_redis_script(neural_common.redis_script_id.vectors_len,
- { task = task, is_write = false },
- vectors_len_cb,
- {
- set.ann.redis_key,
- })
+ { task = task, is_write = false },
+ vectors_len_cb,
+ {
+ set.ann.redis_key,
+ })
else
lua_util.debugm(N, task,
- 'do not push data: train condition not satisfied; reason: not checked existing ANNs')
+ 'do not push data: train condition not satisfied; reason: not checked existing ANNs')
end
else
lua_util.debugm(N, task,
- 'do not push data to key %s: train condition not satisfied; reason: %s',
- (set.ann or {}).redis_key,
- skip_reason)
+ 'do not push data to key %s: train condition not satisfied; reason: %s',
+ (set.ann or {}).redis_key,
+ skip_reason)
end
end
local function redis_ham_cb(err, data)
if err or type(data) ~= 'table' then
rspamd_logger.errx(rspamd_config, 'cannot get ham tokens for ANN %s from redis: %s',
- ann_key, err)
+ ann_key, err)
-- Unlock on error
lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- true, -- is write
- neural_common.gen_unlock_cb(rule, set, ann_key), --callback
- 'HDEL', -- command
- { ann_key, 'lock' }
+ rspamd_config,
+ rule.redis,
+ nil,
+ true, -- is write
+ neural_common.gen_unlock_cb(rule, set, ann_key), --callback
+ 'HDEL', -- command
+ { ann_key, 'lock' }
)
else
-- Decompress and convert to numbers each training vector
local function redis_spam_cb(err, data)
if err or type(data) ~= 'table' then
rspamd_logger.errx(rspamd_config, 'cannot get spam tokens for ANN %s from redis: %s',
- ann_key, err)
+ ann_key, err)
-- Unlock ANN on error
lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- true, -- is write
- neural_common.gen_unlock_cb(rule, set, ann_key), --callback
- 'HDEL', -- command
- { ann_key, 'lock' }
+ rspamd_config,
+ rule.redis,
+ nil,
+ true, -- is write
+ neural_common.gen_unlock_cb(rule, set, ann_key), --callback
+ 'HDEL', -- command
+ { ann_key, 'lock' }
)
else
-- Decompress and convert to numbers each training vector
spam_elts = process_training_vectors(data)
-- Now get ham vectors...
lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- false, -- is write
- redis_ham_cb, --callback
- 'SMEMBERS', -- command
- { ann_key .. '_ham_set' }
+ rspamd_config,
+ rule.redis,
+ nil,
+ false, -- is write
+ redis_ham_cb, --callback
+ 'SMEMBERS', -- command
+ { ann_key .. '_ham_set' }
)
end
end
local function redis_lock_cb(err, data)
if err then
rspamd_logger.errx(rspamd_config, 'cannot call lock script for ANN %s from redis: %s',
- ann_key, err)
+ ann_key, err)
elseif type(data) == 'number' and data == 1 then
-- ANN is locked, so we can extract SPAM and HAM vectors and spawn learning
lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- false, -- is write
- redis_spam_cb, --callback
- 'SMEMBERS', -- command
- { ann_key .. '_spam_set' }
+ rspamd_config,
+ rule.redis,
+ nil,
+ false, -- is write
+ redis_spam_cb, --callback
+ 'SMEMBERS', -- command
+ { ann_key .. '_spam_set' }
)
rspamd_logger.infox(rspamd_config, 'lock ANN %s:%s (key name %s) for learning',
- rule.prefix, set.name, ann_key)
+ rule.prefix, set.name, ann_key)
else
local lock_tm = tonumber(data[1])
rspamd_logger.infox(rspamd_config, 'do not learn ANN %s:%s (key name %s), ' ..
- 'locked by another host %s at %s', rule.prefix, set.name, ann_key,
- data[2], os.date('%c', lock_tm))
+ 'locked by another host %s at %s', rule.prefix, set.name, ann_key,
+ data[2], os.date('%c', lock_tm))
end
end
-- Check if we are already learning this network
if set.learning_spawned then
rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, already learning another ANN',
- ann_key)
+ ann_key)
return
end
-- This script returns either a boolean or a pair {'lock_time', 'hostname'} when
-- ANN is locked by another host (or a process, meh)
lua_redis.exec_redis_script(neural_common.redis_script_id.maybe_lock,
- { ev_base = ev_base, is_write = true },
- redis_lock_cb,
- {
- ann_key,
- tostring(os.time()),
- tostring(math.max(10.0, rule.watch_interval * 2)),
- rspamd_util.get_hostname()
- })
+ { ev_base = ev_base, is_write = true },
+ redis_lock_cb,
+ {
+ ann_key,
+ tostring(os.time()),
+ tostring(math.max(10.0, rule.watch_interval * 2)),
+ rspamd_util.get_hostname()
+ })
end
-- This function loads new ann from Redis
local function data_cb(err, data)
if err then
rspamd_logger.errx(rspamd_config, 'cannot get ANN data from key: %s; %s',
- ann_key, err)
+ ann_key, err)
else
if type(data) == 'table' then
if type(data[1]) == 'userdata' and data[1].cookie == text_cookie then
if _err or not ann_data then
rspamd_logger.errx(rspamd_config, 'cannot decompress ANN for %s from Redis key %s: %s',
- rule.prefix .. ':' .. set.name, ann_key, _err)
+ rule.prefix .. ':' .. set.name, ann_key, _err)
return
else
ann = rspamd_kann.load(ann_data)
end
-- Also update rank for the loaded ANN to avoid removal
lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- true, -- is write
- rank_cb, --callback
- 'ZADD', -- command
- { set.prefix, tostring(rspamd_util.get_time()), profile_serialized }
+ rspamd_config,
+ rule.redis,
+ nil,
+ true, -- is write
+ rank_cb, --callback
+ 'ZADD', -- command
+ { set.prefix, tostring(rspamd_util.get_time()), profile_serialized }
)
rspamd_logger.infox(rspamd_config,
- 'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s',
- rule.prefix, set.name, ann_key, #data[1], profile.version)
+ 'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s',
+ rule.prefix, set.name, ann_key, #data[1], profile.version)
else
rspamd_logger.errx(rspamd_config,
- 'cannot unpack/deserialise ANN for %s:%s from Redis key %s',
- rule.prefix, set.name, ann_key)
+ 'cannot unpack/deserialise ANN for %s:%s from Redis key %s',
+ rule.prefix, set.name, ann_key)
end
end
else
lua_util.debugm(N, rspamd_config, 'missing ANN for %s:%s in Redis key %s',
- rule.prefix, set.name, ann_key)
+ rule.prefix, set.name, ann_key)
end
if set.ann and set.ann.ann and type(data[2]) == 'userdata' and data[2].cookie == text_cookie then
local roc_thresholds = parser:get_object()
set.ann.roc_thresholds = roc_thresholds
rspamd_logger.infox(rspamd_config,
- 'loaded ROC thresholds for %s:%s; version=%s',
- rule.prefix, set.name, profile.version)
+ 'loaded ROC thresholds for %s:%s; version=%s',
+ rule.prefix, set.name, profile.version)
rspamd_logger.debugx("ROC thresholds: %s", roc_thresholds)
end
end
-- We can use PCA
set.ann.pca = rspamd_tensor.load(pca_data)
rspamd_logger.infox(rspamd_config,
- 'loaded PCA for ANN for %s:%s from %s; %s bytes compressed; version=%s',
- rule.prefix, set.name, ann_key, #data[3], profile.version)
+ 'loaded PCA for ANN for %s:%s from %s; %s bytes compressed; version=%s',
+ rule.prefix, set.name, ann_key, #data[3], profile.version)
else
-- no need in pca, why is it there?
rspamd_logger.warnx(rspamd_config,
- 'extra PCA for ANN for %s:%s from Redis key %s: no max inputs defined',
- rule.prefix, set.name, ann_key)
+ 'extra PCA for ANN for %s:%s from Redis key %s: no max inputs defined',
+ rule.prefix, set.name, ann_key)
end
else
-- pca can be missing merely if we have no max_inputs
if rule.max_inputs then
rspamd_logger.errx(rspamd_config, 'cannot unpack/deserialise ANN for %s:%s from Redis key %s: no PCA: %s',
- rule.prefix, set.name, ann_key, _err)
+ rule.prefix, set.name, ann_key, _err)
set.ann.ann = nil
else
-- It is okay
end
else
lua_util.debugm(N, rspamd_config, 'no ANN key for %s:%s in Redis key %s',
- rule.prefix, set.name, ann_key)
+ rule.prefix, set.name, ann_key)
end
end
end
lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- false, -- is write
- data_cb, --callback
- 'HMGET', -- command
- { ann_key, 'ann', 'roc_thresholds', 'pca', 'providers_meta', 'norm_stats' }, -- arguments
- { opaque_data = true }
+ rspamd_config,
+ rule.redis,
+ nil,
+ false, -- is write
+ data_cb, --callback
+ 'HMGET', -- command
+ { ann_key, 'ann', 'roc_thresholds', 'pca', 'providers_meta', 'norm_stats' }, -- arguments
+ { opaque_data = true }
)
end
if set.ann.version < sel_elt.version then
-- Load new ann
rspamd_logger.infox(rspamd_config, 'ann %s is changed, ' ..
- 'our version = %s, remote version = %s',
- rule.prefix .. ':' .. set.name,
- set.ann.version,
- sel_elt.version)
+ 'our version = %s, remote version = %s',
+ rule.prefix .. ':' .. set.name,
+ set.ann.version,
+ sel_elt.version)
load_new_ann(rule, ev_base, set, sel_elt, min_diff)
else
lua_util.debugm(N, rspamd_config, 'ann %s is not changed, ' ..
- 'our version = %s, remote version = %s',
- rule.prefix .. ':' .. set.name,
- set.ann.version,
- sel_elt.version)
+ 'our version = %s, remote version = %s',
+ rule.prefix .. ':' .. set.name,
+ set.ann.version,
+ sel_elt.version)
end
else
-- We have some different ANN, so we need to compare distance
if set.ann.distance > min_diff then
-- Load more specific ANN
rspamd_logger.infox(rspamd_config, 'more specific ann is available for %s, ' ..
- 'our distance = %s, remote distance = %s',
- rule.prefix .. ':' .. set.name,
- set.ann.distance,
- min_diff)
+ 'our distance = %s, remote distance = %s',
+ rule.prefix .. ':' .. set.name,
+ set.ann.distance,
+ min_diff)
load_new_ann(rule, ev_base, set, sel_elt, min_diff)
else
lua_util.debugm(N, rspamd_config, 'ann %s is not changed or less specific, ' ..
- 'our distance = %s, remote distance = %s',
- rule.prefix .. ':' .. set.name,
- set.ann.distance,
- min_diff)
+ 'our distance = %s, remote distance = %s',
+ rule.prefix .. ':' .. set.name,
+ set.ann.distance,
+ min_diff)
end
end
else
local ann_key = sel_elt.redis_key
lua_util.debugm(N, rspamd_config, "check if ANN %s needs to be trained",
- ann_key)
+ ann_key)
-- Create continuation closure
local redis_len_cb_gen = function(cont_cb, what, is_final)
return function(err, data)
if err then
rspamd_logger.errx(rspamd_config,
- 'cannot get ANN %s trains %s from redis: %s', what, ann_key, err)
+ 'cannot get ANN %s trains %s from redis: %s', what, ann_key, err)
elseif data and type(data) == 'number' or type(data) == 'string' then
local ntrains = tonumber(data) or 0
lens[what] = ntrains
end
if max_len >= rule.train.max_trains and fun.all(len_bias_check_pred, lens) then
lua_util.debugm(N, rspamd_config,
- 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
- ann_key, lens, rule.train.max_trains, what)
+ 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
+ ann_key, lens, rule.train.max_trains, what)
cont_cb()
else
lua_util.debugm(N, rspamd_config,
- 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
- ann_key, what, lens, rule.train.max_trains)
+ 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
+ ann_key, what, lens, rule.train.max_trains)
end
else
-- Probabilistic mode, just ensure that at least one vector is okay
if min_len > 0 and max_len >= rule.train.max_trains then
lua_util.debugm(N, rspamd_config,
- 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
- ann_key, lens, rule.train.max_trains, what)
+ 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
+ ann_key, lens, rule.train.max_trains, what)
cont_cb()
else
lua_util.debugm(N, rspamd_config,
- 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
- ann_key, what, lens, rule.train.max_trains)
+ 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
+ ann_key, what, lens, rule.train.max_trains)
end
end
else
lua_util.debugm(N, rspamd_config,
- 'checked %s vectors in ANN %s: %s vectors; %s required, need to check other class vectors',
- what, ann_key, ntrains, rule.train.max_trains)
+ 'checked %s vectors in ANN %s: %s vectors; %s required, need to check other class vectors',
+ what, ann_key, ntrains, rule.train.max_trains)
cont_cb()
end
end
local function initiate_train()
rspamd_logger.infox(rspamd_config,
- 'need to learn ANN %s after %s required learn vectors',
- ann_key, lens)
+ 'need to learn ANN %s after %s required learn vectors',
+ ann_key, lens)
do_train_ann(worker, ev_base, rule, set, ann_key)
end
-- Spam vector is OK, check ham vector length
local function check_ham_len()
lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- false, -- is write
- redis_len_cb_gen(initiate_train, 'ham', true), --callback
- 'SCARD', -- command
- { ann_key .. '_ham_set' }
+ rspamd_config,
+ rule.redis,
+ nil,
+ false, -- is write
+ redis_len_cb_gen(initiate_train, 'ham', true), --callback
+ 'SCARD', -- command
+ { ann_key .. '_ham_set' }
)
end
lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- false, -- is write
- redis_len_cb_gen(check_ham_len, 'spam', false), --callback
- 'SCARD', -- command
- { ann_key .. '_spam_set' }
+ rspamd_config,
+ rule.redis,
+ nil,
+ false, -- is write
+ redis_len_cb_gen(check_ham_len, 'spam', false), --callback
+ 'SCARD', -- command
+ { ann_key .. '_spam_set' }
)
end
end
local res, ucl_err = parser:parse_string(element)
if not res then
rspamd_logger.warnx(rspamd_config, 'cannot parse ANN from redis: %s',
- ucl_err)
+ ucl_err)
return nil
else
local profile = parser:get_object()
local function members_cb(err, data)
if err then
rspamd_logger.errx(cfg, 'cannot get ANNs list from redis: %s',
- err)
+ err)
set.can_store_vectors = true
elseif type(data) == 'table' then
lua_util.debugm(N, cfg, '%s: process element %s:%s',
- what, rule.prefix, set.name)
+ what, rule.prefix, set.name)
process_callback(worker, ev_base, rule, set, fun.map(load_ann_profile, data))
set.can_store_vectors = true
end
-- Select the most appropriate to our profile but it should not differ by more
-- than 30% of symbols
lua_redis.redis_make_request_taskless(ev_base,
- cfg,
- rule.redis,
- nil,
- false, -- is write
- members_cb, --callback
- 'ZREVRANGE', -- command
- { set.prefix, '0', tostring(settings.max_profiles) } -- arguments
+ cfg,
+ rule.redis,
+ nil,
+ false, -- is write
+ members_cb, --callback
+ 'ZREVRANGE', -- command
+ { set.prefix, '0', tostring(settings.max_profiles) } -- arguments
)
end
end -- Cycle over all settings
local function invalidate_cb(err, data)
if err then
rspamd_logger.errx(cfg, 'cannot exec invalidate script in redis: %s',
- err)
+ err)
elseif type(data) == 'table' then
for _, expired in ipairs(data) do
local profile = load_ann_profile(expired)
rspamd_logger.infox(cfg, 'invalidated ANN for %s; redis key: %s; version=%s',
- rule.prefix .. ':' .. set.name,
- profile.redis_key,
- profile.version)
+ rule.prefix .. ':' .. set.name,
+ profile.redis_key,
+ profile.version)
end
end
end
if type(set) == 'table' then
lua_redis.exec_redis_script(neural_common.redis_script_id.maybe_invalidate,
- { ev_base = ev_base, is_write = true },
- invalidate_cb,
- { set.prefix, tostring(settings.max_profiles) })
+ { ev_base = ev_base, is_write = true },
+ invalidate_cb,
+ { set.prefix, tostring(settings.max_profiles) })
end
end
end
if verdict == 'passthrough' then
lua_util.debugm(N, task, 'ignore task as its verdict is %s(%s)',
- verdict, score)
+ verdict, score)
return
end
if score ~= score then
lua_util.debugm(N, task, 'ignore task as its score is nan (%s verdict)',
- verdict)
+ verdict)
return
end
if rule_elt.max_inputs and not has_blas then
rspamd_logger.errx('cannot set max inputs to %s as BLAS is not compiled in',
- rule_elt.name, rule_elt.max_inputs)
+ rule_elt.name, rule_elt.max_inputs)
rule_elt.max_inputs = nil
end
end
if (pcfg.type == 'llm' or pcfg.name == 'llm') and not (pcfg.model or (rspamd_config:get_all_opt('gpt') or {}).model) then
rspamd_logger.errx(rspamd_config,
- 'llm provider in rule %s requires model; please set providers[i].model or gpt.model', k)
+ 'llm provider in rule %s requires model; please set providers[i].model or gpt.model', k)
end
end
end
rspamd_config:add_on_load(function(cfg, ev_base, worker)
if worker:is_scanner() then
rspamd_config:add_periodic(ev_base, 0.0,
- function(_, _)
- return check_anns(worker, cfg, ev_base, rule, process_existing_ann,
- 'try_load_ann')
- end)
+ function(_, _)
+ return check_anns(worker, cfg, ev_base, rule, process_existing_ann,
+ 'try_load_ann')
+ end)
end
if worker:is_primary_controller() then
-- We also want to train neural nets when they have enough data
rspamd_config:add_periodic(ev_base, 0.0,
- function(_, _)
- -- Clean old ANNs
- cleanup_anns(rule, cfg, ev_base)
- return check_anns(worker, cfg, ev_base, rule, maybe_train_existing_ann,
- 'try_train_ann')
- end)
+ function(_, _)
+ -- Clean old ANNs
+ cleanup_anns(rule, cfg, ev_base)
+ return check_anns(worker, cfg, ev_base, rule, maybe_train_existing_ann,
+ 'try_train_ann')
+ end)
end
end)
end