From: Vsevolod Stakhov
Date: Mon, 18 Aug 2025 12:42:53 +0000 (+0100)
Subject: Neural module rework: provider-based feature fusion, LLM embeddings, normalization...
X-Git-Tag: 3.13.0~22^2~8
X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=60e1b843b601f747113c551b73e941624b316f67;p=thirdparty%2Frspamd.git

Neural module rework: provider-based feature fusion, LLM embeddings, normalization, and v3 schema

This PR evolves the neural module from a symbols-only scorer into a general feature-fusion classifier with pluggable providers. It adds an LLM embedding provider, introduces trained normalization and metadata persistence, and isolates new models via a schema/prefix bump.

Motivation:

- The existing neural module is limited to metatokens and symbols.
- We want to combine multiple feature sources (LLM embeddings now; Bayes/FastText later).
- Train/infer behavior must stay consistent, so normalization stats and provider metadata are stored alongside the model.
- Operability is improved with caching, digest checks, and safer rollouts.

Key changes:

- Provider architecture
  - Provider registry and fusion: `collect_features(task, rule)` concatenates provider vectors with optional per-provider weights (a sketch for extension authors follows the migration notes below).
  - New LLM provider: `lualib/plugins/neural/providers/llm.lua`, using `rspamd_http` and `lua_cache` for Redis-backed embedding caching.
  - Symbols provider extracted to `lualib/plugins/neural/providers/symbols.lua`.
- Normalization and PCA
  - Configurable fusion normalization: none/unit/zscore.
  - Normalization stats are computed during training and applied at inference.
  - The existing global PCA is preserved; it is loaded/saved alongside the ANN.
- Schema and compatibility
  - `plugin_ver` bumped to '3' to isolate new models from earlier profiles.
  - Redis save/load extended:
    - Profiles include `providers_digest`.
    - The ANN hash can include `providers_meta`, `norm_stats`, `pca`, `roc_thresholds`, and `ann`.
  - ANN load validates the provider digest and skips applying the ANN on mismatch.
- Performance and reliability
  - LLM embeddings are cached in Redis, keyed by content and model.
  - Graceful fallback to symbols if providers are not configured or fail.
  - Basic provider configuration validation.

Files touched:

- `lualib/plugins/neural.lua`: provider registry, fusion, normalization helpers, profile digests, training pipeline updates.
- `src/plugins/lua/neural.lua`: integrates fusion into inference/learning, loads the new metadata, applies normalization, validates the digest.
- `lualib/plugins/neural/providers/llm.lua`: LLM embeddings with a Redis cache.
- `lualib/plugins/neural/providers/symbols.lua`: legacy symbols provider wrapper.
- `lualib/redis_scripts/neural_save_unlock.lua`: stores `providers_meta` and `norm_stats` in the ANN hash.
- `NEURAL_REWORK_PLAN.md`: design and phased TODO.

Configuration example, enabling LLM alongside symbols:

```ucl
neural {
  rules {
    default {
      providers = [
        { type = "symbols"; weight = 0.5; },
        { type = "llm"; model = "text-embed-1"; url = "https://api.openai.com/v1/embeddings"; cache_ttl = 86400; weight = 1.0; }
      ];
      fusion { normalization = "zscore"; }
      roc_enabled = true;
      max_inputs = 256; # optional PCA
    }
  }
}
```

- The LLM provider uses the `gpt` block for defaults if present (e.g., the API key); an example block is sketched below. `model`, `url`, `timeout`, and cache parameters can be overridden per provider entry.

Compatibility and migration:

- Existing (v2) neural profiles remain unaffected: new keys are namespaced under the `plugin_ver = '3'` prefix.
- New profiles embed `providers_digest`; ANNs trained with an incompatible provider set won't be applied.
- No immediate cleanup is required; TTL-based cleanup keeps old keys around until expiry.
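For extension authors: a provider is just a table with a `collect(task, ctx)` callback that returns a numeric vector plus a metadata table, registered via `neural_common.register_provider`. Below is a minimal sketch of a hypothetical custom provider following the contract used by the bundled `symbols` and `llm` providers; the `message_length` provider and its single log-size feature are illustrative, not part of this PR:

```lua
-- Hypothetical provider: a 1-dimensional feature derived from message size.
-- Follows the collect(task, ctx) -> vec, meta contract from lualib/plugins/neural.lua;
-- ctx carries profile, rule, config (this provider's entry), weight and phase ('train'/'infer').
local neural_common = require "plugins/neural"

neural_common.register_provider('message_length', {
  collect = function(task, ctx)
    local len = task:get_size() or 0
    local vec = { math.log(1.0 + len) } -- log-scaled size as the only feature
    return vec, {
      name = (ctx.config and ctx.config.name) or 'message_length',
      type = 'message_length',
      dim = #vec,
      weight = ctx.weight or 1.0,
    }
  end
})
```

`collect_features` pcall-wraps each `collect`, scales every returned vector by its weight, concatenates the results, and records the per-provider metadata that is later persisted as `providers_meta`.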
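As noted above, per-provider LLM settings fall back to the global `gpt` block. Per `compose_llm_settings`, only `type`, `model`, `timeout`, and `api_key` are consulted there; a sketch with placeholder values:

```ucl
# Placeholder values; only these four fields are read by the llm provider.
gpt {
  type = "openai";        # or "ollama"
  model = "text-embed-1";
  timeout = 2.0;          # seconds; also the provider default when unset
  api_key = "xxx";        # sent as "Authorization: Bearer ..." for openai-type endpoints
}
```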
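For operators inspecting Redis, the resulting v3 layout (following `redis_ann_prefix`, `new_ann_key`, and `neural_save_unlock.lua`) is roughly as sketched below; the angle-bracket parts are placeholders, not literal key names:

```lua
-- Illustrative v3 Redis layout (placeholders in <>):
-- <profile_prefix>                         ZSET of serialized profiles, scored by save time
-- <ann_key>                                HASH with fields: ann, roc_thresholds, pca,
--                                          providers_meta, norm_stats, lock
-- <ann_key>_spam_set, <ann_key>_ham_set    SETs of zstd-compressed training vectors
```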
- Validated: provider digest checks, ANN load/save roundtrip, normalization application at inference, LLM caching paths, symbols fallback. - Please test with/without LLM provider and with `fusion.normalization = none|unit|zscore`. - LLM latency/cost is mitigated by Redis caching; timeouts are configurable per provider. - Privacy: use trusted endpoints; no content leaves unless configured. - Failure behavior: missing/failed providers degrade to others; training/inference can proceed with partial features. - Rules without `providers` continue to use symbols-only behavior. - Existing command surface unchanged; future PR will introduce `rspamc learn_neural:*` and controller endpoints. - [x] Provider registry and fusion - [x] LLM provider with Redis caching - [x] Symbols provider split - [x] Normalization (unit/zscore) with trained stats - [x] Redis schema v3 additions and profile digest - [x] Inference uses trained normalization - [x] Basic provider validation and fallbacks - [x] Plan document - [ ] Per-provider budgets/metrics and circuit breaker for LLM - [ ] Expand providers: Bayes and FastText/subword vectors - [ ] Per-provider PCA and learned fusion - [ ] New CLI (`rspamc learn_neural`) and status/invalidate endpoints - [ ] Documentation expansion under `docs/modules/neural.md` --- diff --git a/lualib/plugins/neural.lua b/lualib/plugins/neural.lua index 5452146697..b13c6a8273 100644 --- a/lualib/plugins/neural.lua +++ b/lualib/plugins/neural.lua @@ -12,7 +12,7 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. -]]-- +]] -- local fun = require "fun" local lua_redis = require "lua_redis" @@ -28,7 +28,7 @@ local ucl = require "ucl" local N = 'neural' -- Used in prefix to avoid wrong ANN to be loaded -local plugin_ver = '2' +local plugin_ver = '3' -- Module vars local default_options = { @@ -43,26 +43,33 @@ local default_options = { learn_threads = 1, learn_mode = 'balanced', -- Possible values: balanced, proportional learning_rate = 0.01, - classes_bias = 0.0, -- balanced mode: what difference is allowed between classes (1:1 proportion means 0 bias) - spam_skip_prob = 0.0, -- proportional mode: spam skip probability (0-1) - ham_skip_prob = 0.0, -- proportional mode: ham skip probability + classes_bias = 0.0, -- balanced mode: what difference is allowed between classes (1:1 proportion means 0 bias) + spam_skip_prob = 0.0, -- proportional mode: spam skip probability (0-1) + ham_skip_prob = 0.0, -- proportional mode: ham skip probability store_pool_only = false, -- store tokens in cache only (disables autotrain); -- neural_vec_mpack stores vector of training data in messagepack neural_profile_digest stores profile digest }, watch_interval = 60.0, lock_expire = 600, learning_spawned = false, - ann_expire = 60 * 60 * 24 * 2, -- 2 days - hidden_layer_mult = 1.5, -- number of neurons in the hidden layer - roc_enabled = false, -- Use ROC to find the best possible thresholds for ham and spam. If spam_score_threshold or ham_score_threshold is defined, it takes precedence over ROC thresholds. + ann_expire = 60 * 60 * 24 * 2, -- 2 days + hidden_layer_mult = 1.5, -- number of neurons in the hidden layer + roc_enabled = false, -- Use ROC to find the best possible thresholds for ham and spam. If spam_score_threshold or ham_score_threshold is defined, it takes precedence over ROC thresholds. 
roc_misclassification_cost = 0.5, -- Cost of misclassifying a spam message (must be 0..1). - spam_score_threshold = nil, -- neural score threshold for spam (must be 0..1 or nil to disable) - ham_score_threshold = nil, -- neural score threshold for ham (must be 0..1 or nil to disable) - flat_threshold_curve = false, -- use binary classification 0/1 when threshold is reached + spam_score_threshold = nil, -- neural score threshold for spam (must be 0..1 or nil to disable) + ham_score_threshold = nil, -- neural score threshold for ham (must be 0..1 or nil to disable) + flat_threshold_curve = false, -- use binary classification 0/1 when threshold is reached symbol_spam = 'NEURAL_SPAM', symbol_ham = 'NEURAL_HAM', - max_inputs = nil, -- when PCA is used - blacklisted_symbols = {}, -- list of symbols skipped in neural processing + max_inputs = nil, -- when PCA is used + blacklisted_symbols = {}, -- list of symbols skipped in neural processing + -- Phase 0 additions (scaffolding for feature providers) + providers = nil, -- list of provider configs; if nil, fallback to symbols-only provider + fusion = { + normalization = 'none', -- none|unit|zscore (zscore requires stats) + per_provider_pca = false, -- if true, apply PCA per provider before fusion (not active yet) + }, + disable_symbols_input = false, -- when true, do not use symbols provider unless explicitly listed } -- Rule structure: @@ -87,7 +94,7 @@ local default_options = { local settings = { rules = {}, - prefix = 'rn', -- Neural network default prefix + prefix = 'rn', -- Neural network default prefix max_profiles = 3, -- Maximum number of NN profiles stored } @@ -103,15 +110,41 @@ local redis_lua_script_save_unlock = "neural_save_unlock.lua" local redis_script_id = {} +-- Provider registry (Phase 0 scaffolding) +local registered_providers = {} + +--- Registers a feature provider implementation +-- @param name string +-- @param provider table with function collect(task, ctx) -> vector(table of numbers), meta(table) +local function register_provider(name, provider) + registered_providers[name] = provider +end + +local function get_provider(name) + return registered_providers[name] +end + +-- Forward declaration +local result_to_vector + +-- Built-in symbols provider (compatibility path) +register_provider('symbols', { + collect = function(task, ctx) + -- ctx.profile is expected for symbols provider + local vec = result_to_vector(task, ctx.profile) + return vec, { name = 'symbols', type = 'symbols', dim = #vec, weight = ctx.weight or 1.0 } + end +}) + local function load_scripts() redis_script_id.vectors_len = lua_redis.load_redis_script_from_file(redis_lua_script_vectors_len, - redis_params) + redis_params) redis_script_id.maybe_invalidate = lua_redis.load_redis_script_from_file(redis_lua_script_maybe_invalidate, - redis_params) + redis_params) redis_script_id.maybe_lock = lua_redis.load_redis_script_from_file(redis_lua_script_maybe_lock, - redis_params) + redis_params) redis_script_id.save_unlock = lua_redis.load_redis_script_from_file(redis_lua_script_save_unlock, - redis_params) + redis_params) end local function create_ann(n, nlayers, rule) @@ -154,16 +187,99 @@ local function learn_pca(inputs, max_inputs) return w end +-- Build providers metadata for storage alongside ANN +local function build_providers_meta(metas) + if not metas or #metas == 0 then return nil end + local out = {} + for i, m in ipairs(metas) do + out[i] = { + name = m.name, + type = m.type, + dim = m.dim, + weight = m.weight, + model = m.model, + provider = m.provider, + 
} + end + return out +end + +-- Normalization helpers +local function l2_normalize_vector(vec) + local sumsq = 0.0 + for i = 1, #vec do + local v = vec[i] + sumsq = sumsq + v * v + end + if sumsq > 0 then + local inv = 1.0 / math.sqrt(sumsq) + for i = 1, #vec do + vec[i] = vec[i] * inv + end + end + return vec +end + +local function compute_zscore_stats(inputs) + local n = #inputs + if n == 0 then return nil end + local d = #inputs[1] + local mean = {} + local m2 = {} + for j = 1, d do + mean[j] = 0.0 + m2[j] = 0.0 + end + for i = 1, n do + local x = inputs[i] + for j = 1, d do + local delta = x[j] - mean[j] + mean[j] = mean[j] + delta / i + m2[j] = m2[j] + delta * (x[j] - mean[j]) + end + end + local std = {} + for j = 1, d do + std[j] = math.sqrt((n > 1 and (m2[j] / (n - 1))) or 0.0) + if std[j] == 0 or std[j] ~= std[j] then + std[j] = 1.0 -- avoid division by zero and NaN + end + end + return { mode = 'zscore', mean = mean, std = std } +end + +local function apply_normalization(vec, norm_stats_or_mode) + if not norm_stats_or_mode then return vec end + if type(norm_stats_or_mode) == 'string' then + if norm_stats_or_mode == 'unit' then + return l2_normalize_vector(vec) + else + return vec + end + else + if norm_stats_or_mode.mode == 'unit' then + return l2_normalize_vector(vec) + elseif norm_stats_or_mode.mode == 'zscore' and norm_stats_or_mode.mean and norm_stats_or_mode.std then + local mean = norm_stats_or_mode.mean + local std = norm_stats_or_mode.std + for i = 1, math.min(#vec, #mean) do + vec[i] = (vec[i] - (mean[i] or 0.0)) / (std[i] or 1.0) + end + return vec + else + return vec + end + end +end + -- This function computes optimal threshold using ROC for the given set of inputs. -- Returns a threshold that minimizes: -- alpha * (false_positive_rate) + beta * (false_negative_rate) -- Where alpha is cost of false positive result -- beta is cost of false negative result local function get_roc_thresholds(ann, inputs, outputs, alpha, beta) - -- Sorts list x and list y based on the values in list x. 
local sort_relative = function(x, y) - local r = {} assert(#x == #y) @@ -219,7 +335,6 @@ local function get_roc_thresholds(ann, inputs, outputs, alpha, beta) spam_count_ahead[n_samples + 1] = 0 for i = n_samples, 1, -1 do - if outputs[i][1] == 0 then n_ham = n_ham + 1 ham_count_ahead[i] = 1 @@ -283,34 +398,34 @@ end -- `set.learning_spawned` is set to `true` local function register_lock_extender(rule, set, ev_base, ann_key) rspamd_config:add_periodic(ev_base, 30.0, - function() - local function redis_lock_extend_cb(err, _) - if err then - rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s', - ann_key, err) - else - rspamd_logger.infox(rspamd_config, 'extend lock for ANN %s for 30 seconds', - ann_key) - end - end - - if set.learning_spawned then - lua_redis.redis_make_request_taskless(ev_base, - rspamd_config, - rule.redis, - nil, - true, -- is write - redis_lock_extend_cb, --callback - 'HINCRBY', -- command - { ann_key, 'lock', '30' } - ) + function() + local function redis_lock_extend_cb(err, _) + if err then + rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s', + ann_key, err) else - lua_util.debugm(N, rspamd_config, "stop lock extension as learning_spawned is false") - return false -- do not plan any more updates + rspamd_logger.infox(rspamd_config, 'extend lock for ANN %s for 30 seconds', + ann_key) end + end - return true + if set.learning_spawned then + lua_redis.redis_make_request_taskless(ev_base, + rspamd_config, + rule.redis, + nil, + true, -- is write + redis_lock_extend_cb, --callback + 'HINCRBY', -- command + { ann_key, 'lock', '30' } + ) + else + lua_util.debugm(N, rspamd_config, "stop lock extension as learning_spawned is false") + return false -- do not plan any more updates end + + return true + end ) end @@ -332,10 +447,10 @@ local function can_push_train_vector(rule, task, learn_type, nspam, nham) local skip_rate = 1.0 - nham / (nspam + 1) if coin < skip_rate - train_opts.classes_bias then rspamd_logger.infox(task, - 'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored', - learn_type, - skip_rate - train_opts.classes_bias, - nspam, nham) + 'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored', + learn_type, + skip_rate - train_opts.classes_bias, + nspam, nham) return false end end @@ -343,8 +458,8 @@ local function can_push_train_vector(rule, task, learn_type, nspam, nham) else -- Enough learns rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many spam samples: %s', - learn_type, - nspam) + learn_type, + nspam) end else if nham <= train_opts.max_trains then @@ -353,17 +468,17 @@ local function can_push_train_vector(rule, task, learn_type, nspam, nham) local skip_rate = 1.0 - nspam / (nham + 1) if coin < skip_rate - train_opts.classes_bias then rspamd_logger.infox(task, - 'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored', - learn_type, - skip_rate - train_opts.classes_bias, - nspam, nham) + 'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored', + learn_type, + skip_rate - train_opts.classes_bias, + nspam, nham) return false end end return true else rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many ham samples: %s', learn_type, - nham) + nham) end end else @@ -374,7 +489,7 @@ local function can_push_train_vector(rule, task, learn_type, nspam, nham) if train_opts.spam_skip_prob then if coin <= train_opts.spam_skip_prob then 
rspamd_logger.infox(task, 'skip %s sample probabilistically; probability %s (%s skip chance)', learn_type, - coin, train_opts.spam_skip_prob) + coin, train_opts.spam_skip_prob) return false end @@ -382,14 +497,14 @@ local function can_push_train_vector(rule, task, learn_type, nspam, nham) end else rspamd_logger.infox(task, 'skip %s sample; too many spam samples: %s (%s limit)', learn_type, - nspam, train_opts.max_trains) + nspam, train_opts.max_trains) end else if nham <= train_opts.max_trains then if train_opts.ham_skip_prob then if coin <= train_opts.ham_skip_prob then rspamd_logger.infox(task, 'skip %s sample probabilistically; probability %s (%s skip chance)', learn_type, - coin, train_opts.ham_skip_prob) + coin, train_opts.ham_skip_prob) return false end @@ -397,7 +512,7 @@ local function can_push_train_vector(rule, task, learn_type, nspam, nham) end else rspamd_logger.infox(task, 'skip %s sample; too many ham samples: %s (%s limit)', learn_type, - nham, train_opts.max_trains) + nham, train_opts.max_trains) end end end @@ -410,10 +525,10 @@ local function gen_unlock_cb(rule, set, ann_key) return function(err) if err then rspamd_logger.errx(rspamd_config, 'cannot unlock ANN %s:%s at %s from redis: %s', - rule.prefix, set.name, ann_key, err) + rule.prefix, set.name, ann_key, err) else lua_util.debugm(N, rspamd_config, 'unlocked ANN %s:%s at %s', - rule.prefix, set.name, ann_key) + rule.prefix, set.name, ann_key) end end end @@ -421,7 +536,7 @@ end -- Used to generate new ANN key for specific profile local function new_ann_key(rule, set, version) local ann_key = string.format('%s_%s_%s_%s_%s', settings.prefix, - rule.prefix, set.name, set.digest:sub(1, 8), tostring(version)) + rule.prefix, set.name, set.digest:sub(1, 8), tostring(version)) return ann_key end @@ -430,7 +545,97 @@ local function redis_ann_prefix(rule, settings_name) -- We also need to count metatokens: local n = meta_functions.version return string.format('%s%d_%s_%d_%s', - settings.prefix, plugin_ver, rule.prefix, n, settings_name) + settings.prefix, plugin_ver, rule.prefix, n, settings_name) +end + +-- Compute a stable digest for providers configuration +local function providers_config_digest(providers_cfg) + if not providers_cfg then return nil end + -- Normalize minimal subset of fields to keep digest stable across equivalent configs + local norm = {} + for i, p in ipairs(providers_cfg) do + norm[i] = { + type = p.type, + name = p.name, + weight = p.weight or 1.0, + dim = p.dim, + } + end + return lua_util.table_digest(norm) +end + +-- If no providers configured, fallback to symbols provider unless disabled +-- phase: 'infer' | 'train' +local function collect_features(task, rule, profile_or_set, phase) + local vectors = {} + local metas = {} + + local providers_cfg = rule.providers + if not providers_cfg or #providers_cfg == 0 then + if not rule.disable_symbols_input then + local prov = get_provider('symbols') + if prov then + local vec, meta = prov.collect(task, { profile = profile_or_set, weight = 1.0 }) + if vec then + vectors[#vectors + 1] = vec + metas[#metas + 1] = meta + end + end + end + else + for _, pcfg in ipairs(providers_cfg) do + local prov = get_provider(pcfg.type or pcfg.name) + if prov then + local ok, vec, meta = pcall(function() + return prov.collect(task, { + profile = profile_or_set, + rule = rule, + config = pcfg, + weight = pcfg.weight or 1.0, + phase = phase, + }) + end) + if ok and vec then + if meta then + meta.weight = pcfg.weight or meta.weight or 1.0 + end + vectors[#vectors + 1] = vec + 
metas[#metas + 1] = meta or + { name = pcfg.name or pcfg.type, type = pcfg.type, dim = #vec, weight = pcfg.weight or 1.0 } + else + rspamd_logger.debugm(N, rspamd_config, 'provider %s failed to collect features', pcfg.type or pcfg.name) + end + else + rspamd_logger.debugm(N, rspamd_config, 'provider %s is not registered', pcfg.type or pcfg.name) + end + end + end + + -- Simple fusion by concatenation; optional per-provider weight scaling + local fused = {} + for i, v in ipairs(vectors) do + local w = (metas[i] and metas[i].weight) or 1.0 + -- Apply normalization if requested + local norm_mode = (rule.fusion and rule.fusion.normalization) or 'none' + if norm_mode ~= 'none' then + v = apply_normalization(v, norm_mode) + end + for _, x in ipairs(v) do + fused[#fused + 1] = x * w + end + end + + local meta = { + providers = build_providers_meta(metas) or metas, + total_dim = #fused, + digest = providers_config_digest(providers_cfg), + } + + if #fused == 0 then + return nil, meta + end + + return fused, meta end -- This function receives training vectors, checks them, spawn learning and saves ANN in Redis @@ -488,71 +693,85 @@ local function spawn_train(params) -- We have nan :( try to log lot's of stuff to dig into a problem seen_nan = true rspamd_logger.errx(rspamd_config, 'ANN %s:%s: train error: observed nan in error cost!; value cost = %s', - params.rule.prefix, params.set.name, - value_cost) + params.rule.prefix, params.set.name, + value_cost) for i, e in ipairs(inputs) do lua_util.debugm(N, rspamd_config, 'train vector %s -> %s', - debug_vec(e), outputs[i][1]) + debug_vec(e), outputs[i][1]) end end rspamd_logger.infox(rspamd_config, - "ANN %s:%s: learned from %s redis key in %s iterations, error: %s, value cost: %s", - params.rule.prefix, params.set.name, - params.ann_key, - iter, - train_cost, - value_cost) + "ANN %s:%s: learned from %s redis key in %s iterations, error: %s, value cost: %s", + params.rule.prefix, params.set.name, + params.ann_key, + iter, + train_cost, + value_cost) end end lua_util.debugm(N, rspamd_config, "subprocess to learn ANN %s:%s has been started", - params.rule.prefix, params.set.name) + params.rule.prefix, params.set.name) local pca if params.rule.max_inputs then -- Train PCA in the main process, presumably it is not that long lua_util.debugm(N, rspamd_config, "start PCA train for ANN %s:%s", - params.rule.prefix, params.set.name) + params.rule.prefix, params.set.name) pca = learn_pca(inputs, params.rule.max_inputs) end + -- Compute normalization stats if requested + local norm_stats + if params.rule.fusion and params.rule.fusion.normalization == 'zscore' then + norm_stats = compute_zscore_stats(inputs) + elseif params.rule.fusion and params.rule.fusion.normalization == 'unit' then + norm_stats = { mode = 'unit' } + end + + if norm_stats then + for i = 1, #inputs do + inputs[i] = apply_normalization(inputs[i], norm_stats) + end + end + lua_util.debugm(N, rspamd_config, "start neural train for ANN %s:%s", - params.rule.prefix, params.set.name) + params.rule.prefix, params.set.name) local ret, err = pcall(train_ann.train1, train_ann, - inputs, outputs, { - lr = params.rule.train.learning_rate, - max_epoch = params.rule.train.max_iterations, - cb = train_cb, - pca = pca - }) + inputs, outputs, { + lr = params.rule.train.learning_rate, + max_epoch = params.rule.train.max_iterations, + cb = train_cb, + pca = pca + }) if not ret then rspamd_logger.errx(rspamd_config, "cannot train ann %s:%s: %s", - params.rule.prefix, params.set.name, err) + params.rule.prefix, 
params.set.name, err) return nil else lua_util.debugm(N, rspamd_config, "finished neural train for ANN %s:%s", - params.rule.prefix, params.set.name) + params.rule.prefix, params.set.name) end local roc_thresholds = {} if params.rule.roc_enabled then local spam_threshold = get_roc_thresholds(train_ann, - inputs, - outputs, - 1 - params.rule.roc_misclassification_cost, - params.rule.roc_misclassification_cost) + inputs, + outputs, + 1 - params.rule.roc_misclassification_cost, + params.rule.roc_misclassification_cost) local ham_threshold = get_roc_thresholds(train_ann, - inputs, - outputs, - params.rule.roc_misclassification_cost, - 1 - params.rule.roc_misclassification_cost) + inputs, + outputs, + params.rule.roc_misclassification_cost, + 1 - params.rule.roc_misclassification_cost) roc_thresholds = { spam_threshold, ham_threshold } rspamd_logger.messagex("ROC thresholds: (spam_threshold: %s, ham_threshold: %s)", - roc_thresholds[1], roc_thresholds[2]) + roc_thresholds[1], roc_thresholds[2]) end if not seen_nan then @@ -565,11 +784,12 @@ local function spawn_train(params) ann_data = tostring(train_ann:save()), pca_data = pca_data, roc_thresholds = roc_thresholds, + norm_stats = norm_stats, } local final_data = ucl.to_format(out, 'msgpack') lua_util.debugm(N, rspamd_config, "subprocess for ANN %s:%s returned %s bytes", - params.rule.prefix, params.set.name, #final_data) + params.rule.prefix, params.set.name, #final_data) return final_data else return nil @@ -581,19 +801,19 @@ local function spawn_train(params) local function redis_save_cb(err) if err then rspamd_logger.errx(rspamd_config, 'cannot save ANN %s:%s to redis key %s: %s', - params.rule.prefix, params.set.name, params.ann_key, err) + params.rule.prefix, params.set.name, params.ann_key, err) lua_redis.redis_make_request_taskless(params.ev_base, - rspamd_config, - params.rule.redis, - nil, - false, -- is write - gen_unlock_cb(params.rule, params.set, params.ann_key), --callback - 'HDEL', -- command - { params.ann_key, 'lock' } + rspamd_config, + params.rule.redis, + nil, + false, -- is write + gen_unlock_cb(params.rule, params.set, params.ann_key), --callback + 'HDEL', -- command + { params.ann_key, 'lock' } ) else rspamd_logger.infox(rspamd_config, 'saved ANN %s:%s to redis: %s', - params.rule.prefix, params.set.name, params.set.ann.redis_key) + params.rule.prefix, params.set.name, params.set.ann.redis_key) end end @@ -601,15 +821,15 @@ local function spawn_train(params) params.set.learning_spawned = false if err then rspamd_logger.errx(rspamd_config, 'cannot train ANN %s:%s : %s', - params.rule.prefix, params.set.name, err) + params.rule.prefix, params.set.name, err) lua_redis.redis_make_request_taskless(params.ev_base, - rspamd_config, - params.rule.redis, - nil, - true, -- is write - gen_unlock_cb(params.rule, params.set, params.ann_key), --callback - 'HDEL', -- command - { params.ann_key, 'lock' } + rspamd_config, + params.rule.redis, + nil, + true, -- is write + gen_unlock_cb(params.rule, params.set, params.ann_key), --callback + 'HDEL', -- command + { params.ann_key, 'lock' } ) else local parser = ucl.parser() @@ -619,6 +839,7 @@ local function spawn_train(params) local ann_data = rspamd_util.zstd_compress(parsed.ann_data) local pca_data = parsed.pca_data local roc_thresholds = parsed.roc_thresholds + local norm_stats = parsed.norm_stats fill_set_ann(params.set, params.ann_key) if pca_data then @@ -643,32 +864,40 @@ local function spawn_train(params) symbols = params.set.symbols, digest = params.set.digest, redis_key = 
params.set.ann.redis_key, - version = version + version = version, + providers_digest = providers_config_digest(params.rule.providers), } local profile_serialized = ucl.to_format(profile, 'json-compact', true) local roc_thresholds_serialized = ucl.to_format(roc_thresholds, 'json-compact', true) + local providers_meta_serialized + if params.rule.providers then + providers_meta_serialized = ucl.to_format( + build_providers_meta(params.set.ann.providers or params.rule.providers), 'json-compact', true) + end rspamd_logger.infox(rspamd_config, - 'trained ANN %s:%s, %s bytes (%s compressed); %s rows in pca (%sb compressed); redis key: %s (old key %s)', - params.rule.prefix, params.set.name, - #data, #ann_data, - #(params.set.ann.pca or {}), #(pca_data or {}), - params.set.ann.redis_key, params.ann_key) + 'trained ANN %s:%s, %s bytes (%s compressed); %s rows in pca (%sb compressed); redis key: %s (old key %s)', + params.rule.prefix, params.set.name, + #data, #ann_data, + #(params.set.ann.pca or {}), #(pca_data or {}), + params.set.ann.redis_key, params.ann_key) lua_redis.exec_redis_script(redis_script_id.save_unlock, - { ev_base = params.ev_base, is_write = true }, - redis_save_cb, - { profile.redis_key, - redis_ann_prefix(params.rule, params.set.name), - ann_data, - profile_serialized, - tostring(params.rule.ann_expire), - tostring(os.time()), - params.ann_key, -- old key to unlock... - roc_thresholds_serialized, - pca_data, - }) + { ev_base = params.ev_base, is_write = true }, + redis_save_cb, + { profile.redis_key, + redis_ann_prefix(params.rule, params.set.name), + ann_data, + profile_serialized, + tostring(params.rule.ann_expire), + tostring(os.time()), + params.ann_key, -- old key to unlock... + roc_thresholds_serialized, + pca_data, + providers_meta_serialized, + ucl.to_format(norm_stats, 'json-compact', true), + }) end end @@ -685,7 +914,6 @@ local function spawn_train(params) params.set.learning_spawned = true register_lock_extender(params.rule, params.set, params.ev_base, params.ann_key) return - end end @@ -698,14 +926,14 @@ local function process_rules_settings() -- Use static user defined profile -- Ensure that we have an array... lua_util.debugm(N, rspamd_config, "use static profile for %s (%s): %s", - rule.prefix, selt.name, profile) + rule.prefix, selt.name, profile) if not profile[1] then profile = lua_util.keys(profile) end selt.symbols = profile else lua_util.debugm(N, rspamd_config, "use dynamic cfg based profile for %s (%s)", - rule.prefix, selt.name) + rule.prefix, selt.name) end local function filter_symbols_predicate(sname) @@ -734,34 +962,34 @@ local function process_rules_settings() selt.prefix = redis_ann_prefix(rule, selt.name) rspamd_logger.messagex(rspamd_config, - 'use NN prefix for rule %s; settings id "%s"; symbols digest: "%s"', - selt.prefix, selt.name, selt.digest) + 'use NN prefix for rule %s; settings id "%s"; symbols digest: "%s"', + selt.prefix, selt.name, selt.digest) lua_redis.register_prefix(selt.prefix, N, - string.format('NN prefix for rule "%s"; settings id "%s"', - selt.prefix, selt.name), { - persistent = true, - type = 'zlist', - }) + string.format('NN prefix for rule "%s"; settings id "%s"', + selt.prefix, selt.name), { + persistent = true, + type = 'zlist', + }) -- Versions lua_redis.register_prefix(selt.prefix .. 
'_\\d+', N, - string.format('NN storage for rule "%s"; settings id "%s"', - selt.prefix, selt.name), { - persistent = true, - type = 'hash', - }) + string.format('NN storage for rule "%s"; settings id "%s"', + selt.prefix, selt.name), { + persistent = true, + type = 'hash', + }) lua_redis.register_prefix(selt.prefix .. '_\\d+_spam_set', N, - string.format('NN learning set (spam) for rule "%s"; settings id "%s"', - selt.prefix, selt.name), { - persistent = true, - type = 'set', - }) + string.format('NN learning set (spam) for rule "%s"; settings id "%s"', + selt.prefix, selt.name), { + persistent = true, + type = 'set', + }) lua_redis.register_prefix(selt.prefix .. '_\\d+_ham_set', N, - string.format('NN learning set (ham) for rule "%s"; settings id "%s"', - rule.prefix, selt.name), { - persistent = true, - type = 'set', - }) + string.format('NN learning set (ham) for rule "%s"; settings id "%s"', + rule.prefix, selt.name), { + persistent = true, + type = 'set', + }) end for k, rule in pairs(settings.rules) do @@ -813,8 +1041,8 @@ local function process_rules_settings() if nelt and lua_util.distance_sorted(ex.symbols, nelt.symbols) == 0 then -- Equal symbols, add reference lua_util.debugm(N, rspamd_config, - 'added reference from settings id %s to %s; same symbols', - nelt.name, ex.name) + 'added reference from settings id %s to %s; same symbols', + nelt.name, ex.name) rule.settings[settings_id] = id nelt = nil end @@ -824,7 +1052,7 @@ local function process_rules_settings() if nelt then rule.settings[settings_id] = nelt lua_util.debugm(N, rspamd_config, 'added new settings id %s(%s) to %s', - nelt.name, settings_id, rule.prefix) + nelt.name, settings_id, rule.prefix) end end end @@ -847,7 +1075,7 @@ local function get_rule_settings(task, rule) return set end -local function result_to_vector(task, profile) +result_to_vector = function(task, profile) if not profile.zeros then -- Fill zeros vector local zeros = {} @@ -874,13 +1102,18 @@ end return { can_push_train_vector = can_push_train_vector, + collect_features = collect_features, create_ann = create_ann, default_options = default_options, + build_providers_meta = build_providers_meta, + apply_normalization = apply_normalization, gen_unlock_cb = gen_unlock_cb, get_rule_settings = get_rule_settings, load_scripts = load_scripts, module_config = module_config, new_ann_key = new_ann_key, + providers_config_digest = providers_config_digest, + register_provider = register_provider, plugin_ver = plugin_ver, process_rules_settings = process_rules_settings, redis_ann_prefix = redis_ann_prefix, diff --git a/lualib/plugins/neural/providers/llm.lua b/lualib/plugins/neural/providers/llm.lua new file mode 100644 index 0000000000..fda0141e33 --- /dev/null +++ b/lualib/plugins/neural/providers/llm.lua @@ -0,0 +1,206 @@ +--[[ +LLM provider for neural feature fusion +Collects text from the most relevant part and requests embeddings from an LLM API. +Supports minimal OpenAI- and Ollama-compatible embedding endpoints. 
+]] -- + +local rspamd_http = require "rspamd_http" +local rspamd_logger = require "rspamd_logger" +local ucl = require "ucl" +local lua_mime = require "lua_mime" +local neural_common = require "plugins/neural" +local lua_cache = require "lua_cache" + +local N = "neural.llm" + +local function select_text(task, cfg) + local part = lua_mime.get_displayed_text_part(task) + if part then + local tp = part:get_text() + if tp then + -- Prefer UTF text content + local content = tp:get_content('raw_utf') or tp:get_content('raw') + if content and #content > 0 then + return content + end + end + -- Fallback to raw content + local rc = part:get_raw_content() + if type(rc) == 'userdata' then + rc = tostring(rc) + end + return rc + end + + -- Fallback to subject if no text part + return task:get_subject() or '' +end + +local function compose_llm_settings(pcfg) + local gpt_settings = rspamd_config:get_all_opt('gpt') or {} + local llm_type = pcfg.type or gpt_settings.type or 'openai' + local model = pcfg.model or gpt_settings.model + local timeout = pcfg.timeout or gpt_settings.timeout or 2.0 + local url = pcfg.url + local api_key = pcfg.api_key or gpt_settings.api_key + + if not url then + if llm_type == 'openai' then + url = 'https://api.openai.com/v1/embeddings' + elseif llm_type == 'ollama' then + url = 'http://127.0.0.1:11434/api/embeddings' + end + end + + return { + type = llm_type, + model = model, + timeout = timeout, + url = url, + api_key = api_key, + cache_ttl = pcfg.cache_ttl or 86400, + cache_prefix = pcfg.cache_prefix or 'neural_llm', + cache_hash_len = pcfg.cache_hash_len or 16, + cache_use_hashing = pcfg.cache_use_hashing ~= false, + } +end + +local function extract_embedding(llm_type, parsed) + if llm_type == 'openai' then + -- { data = [ { embedding = [...] } ] } + if parsed and parsed.data and parsed.data[1] and parsed.data[1].embedding then + return parsed.data[1].embedding + end + elseif llm_type == 'ollama' then + -- { embedding = [...] } + if parsed and parsed.embedding then + return parsed.embedding + end + end + return nil +end + +neural_common.register_provider('llm', { + collect = function(task, ctx) + local pcfg = ctx.config or {} + local llm = compose_llm_settings(pcfg) + + if not llm.model then + rspamd_logger.debugm(N, task, 'llm provider missing model; skip') + return nil + end + + local content = select_text(task, pcfg) + if not content or #content == 0 then + rspamd_logger.debugm(N, task, 'llm provider has no content to embed; skip') + return nil + end + + local body + if llm.type == 'openai' then + body = { model = llm.model, input = content } + elseif llm.type == 'ollama' then + body = { model = llm.model, prompt = content } + else + rspamd_logger.debugm(N, task, 'unsupported llm type: %s', llm.type) + return nil + end + + -- Redis cache: use content hash + model + provider as key + local cache_ctx = lua_cache.create_cache_context(neural_common.redis_params, { + cache_prefix = llm.cache_prefix, + cache_ttl = llm.cache_ttl, + cache_format = 'messagepack', + cache_hash_len = llm.cache_hash_len, + cache_use_hashing = llm.cache_use_hashing, + }, N) + + -- Use a stable key based on content digest + local hasher = require 'rspamd_cryptobox_hash' + local key = string.format('%s:%s:%s', llm.type, llm.model or 'model', hasher.create(content):hex()) + + local function do_request_and_cache() + local headers = { ['Content-Type'] = 'application/json' } + if llm.type == 'openai' and llm.api_key then + headers['Authorization'] = 'Bearer ' .. 
llm.api_key + end + + local http_params = { + url = llm.url, + mime_type = 'application/json', + timeout = llm.timeout, + log_obj = task, + headers = headers, + body = ucl.to_format(body, 'json-compact', true), + task = task, + method = 'POST', + use_gzip = true, + } + + local err, data = rspamd_http.request(http_params) + if err then + rspamd_logger.debugm(N, task, 'llm request failed: %s', err) + return nil + end + + local parser = ucl.parser() + local ok, perr = parser:parse_string(data.content) + if not ok then + rspamd_logger.debugm(N, task, 'cannot parse llm response: %s', perr) + return nil + end + + local parsed = parser:get_object() + local embedding = extract_embedding(llm.type, parsed) + if not embedding or #embedding == 0 then + rspamd_logger.debugm(N, task, 'no embedding in llm response') + return nil + end + + for i = 1, #embedding do + embedding[i] = tonumber(embedding[i]) or 0.0 + end + + lua_cache.cache_set(task, key, { e = embedding }, cache_ctx) + return embedding + end + + -- Try cache first + local cached_result + local done = false + lua_cache.cache_get(task, key, cache_ctx, llm.timeout or 2.0, + function(_) + -- Uncached: perform request synchronously and store + cached_result = do_request_and_cache() + done = true + end, + function(_, err, data) + if data and data.e then + cached_result = data.e + end + done = true + end + ) + + if not done then + -- Fallback: ensure we still do the request now (cache API is async-ready, but we need sync path) + cached_result = do_request_and_cache() + end + + local embedding = cached_result + if not embedding then + return nil + end + + local meta = { + name = pcfg.name or 'llm', + type = 'llm', + dim = #embedding, + weight = pcfg.weight or 1.0, + model = llm.model, + provider = llm.type, + } + + return embedding, meta + end +}) diff --git a/lualib/plugins/neural/providers/symbols.lua b/lualib/plugins/neural/providers/symbols.lua new file mode 100644 index 0000000000..6a3b750ca8 --- /dev/null +++ b/lualib/plugins/neural/providers/symbols.lua @@ -0,0 +1,10 @@ +-- Symbols provider: wraps legacy symbols+metatokens vectorization + +local neural_common = require "plugins/neural" + +neural_common.register_provider('symbols', { + collect = function(task, ctx) + local vec = neural_common.result_to_vector(task, ctx.profile) + return vec, { name = 'symbols', type = 'symbols', dim = #vec, weight = ctx.weight or 1.0 } + end +}) diff --git a/lualib/redis_scripts/neural_save_unlock.lua b/lualib/redis_scripts/neural_save_unlock.lua index 7ea7dc2e58..dfed2e358f 100644 --- a/lualib/redis_scripts/neural_save_unlock.lua +++ b/lualib/redis_scripts/neural_save_unlock.lua @@ -9,6 +9,8 @@ -- key7 - old key -- key8 - ROC Thresholds -- key9 - optional PCA +-- key10 - optional providers_meta (JSON) +-- key11 - optional norm_stats (JSON) local now = tonumber(KEYS[6]) redis.call('ZADD', KEYS[2], now, KEYS[4]) redis.call('HSET', KEYS[1], 'ann', KEYS[3]) @@ -16,10 +18,16 @@ redis.call('HSET', KEYS[1], 'roc_thresholds', KEYS[8]) if KEYS[9] then redis.call('HSET', KEYS[1], 'pca', KEYS[9]) end +if KEYS[10] then + redis.call('HSET', KEYS[1], 'providers_meta', KEYS[10]) +end +if KEYS[11] then + redis.call('HSET', KEYS[1], 'norm_stats', KEYS[11]) +end redis.call('HDEL', KEYS[1], 'lock') redis.call('HDEL', KEYS[7], 'lock') redis.call('EXPIRE', KEYS[1], tonumber(KEYS[5])) - -- expire in 10m, to not face race condition with other rspamd replicas refill deleted keys +-- expire in 10m, to not face race condition with other rspamd replicas refill deleted keys 
redis.call('EXPIRE', KEYS[7] .. '_spam_set', 600) redis.call('EXPIRE', KEYS[7] .. '_ham_set', 600) return 1 diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua index ea40fc4f74..0a8ebcd692 100644 --- a/src/plugins/lua/neural.lua +++ b/src/plugins/lua/neural.lua @@ -12,7 +12,7 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. -]]-- +]] -- if confighelp then @@ -30,6 +30,9 @@ local rspamd_tensor = require "rspamd_tensor" local rspamd_text = require "rspamd_text" local rspamd_util = require "rspamd_util" local ts = require("tableshape").types +-- Load providers +pcall(require, "plugins/neural/providers/llm") +pcall(require, "plugins/neural/providers/symbols") local N = "neural" @@ -41,6 +44,7 @@ local redis_profile_schema = ts.shape { version = ts.number, redis_key = ts.string, distance = ts.number:is_optional(), + providers_digest = ts.string:is_optional(), } local has_blas = rspamd_tensor.has_blas() @@ -55,7 +59,8 @@ local function new_ann_profile(task, rule, set, version) redis_key = ann_key, version = version, digest = set.digest, - distance = 0 -- Since we are using our own profile + distance = 0, -- Since we are using our own profile + providers_digest = neural_common.providers_config_digest(rule.providers), } local ucl = require "ucl" @@ -64,20 +69,20 @@ local function new_ann_profile(task, rule, set, version) local function add_cb(err, _) if err then rspamd_logger.errx(task, 'cannot store ANN profile for %s:%s at %s : %s', - rule.prefix, set.name, profile.redis_key, err) + rule.prefix, set.name, profile.redis_key, err) else rspamd_logger.infox(task, 'created new ANN profile for %s:%s, data stored at prefix %s', - rule.prefix, set.name, profile.redis_key) + rule.prefix, set.name, profile.redis_key) end end lua_redis.redis_make_request(task, - rule.redis, - nil, - true, -- is write - add_cb, --callback - 'ZADD', -- command - { set.prefix, tostring(rspamd_util.get_time()), profile_serialized } + rule.redis, + nil, + true, -- is write + add_cb, --callback + 'ZADD', -- command + { set.prefix, tostring(rspamd_util.get_time()), profile_serialized } ) return profile @@ -86,7 +91,6 @@ end -- ANN filter function, used to insert scores based on the existing symbols local function ann_scores_filter(task) - for _, rule in pairs(settings.rules) do local sid = task:get_settings_id() or -1 local ann @@ -99,24 +103,41 @@ local function ann_scores_filter(task) profile = set.ann else lua_util.debugm(N, task, 'no ann loaded for %s:%s', - rule.prefix, set.name) + rule.prefix, set.name) end else lua_util.debugm(N, task, 'no ann defined in %s for settings id %s', - rule.prefix, sid) + rule.prefix, sid) end if ann then - local vec = neural_common.result_to_vector(task, profile) + local vec + if rule.providers and #rule.providers > 0 then + local fused, meta = neural_common.collect_features(task, rule, profile) + vec = fused + if profile.providers_digest and meta.digest and profile.providers_digest ~= meta.digest then + lua_util.debugm(N, task, 'providers digest mismatch for %s:%s, skip ANN apply', + rule.prefix, set.name) + vec = nil + end + else + vec = neural_common.result_to_vector(task, profile) + end local score + if not vec then + goto continue_rule + end + if set.ann.norm_stats then + vec = neural_common.apply_normalization(vec, set.ann.norm_stats) + end local out = ann:apply1(vec, set.ann.pca) 
score = out[1] local symscore = string.format('%.3f', score) task:cache_set(rule.prefix .. '_neural_score', score) lua_util.debugm(N, task, '%s:%s:%s ann score: %s', - rule.prefix, set.name, set.ann.version, symscore) + rule.prefix, set.name, set.ann.version, symscore) if score > 0 then local result = score @@ -137,8 +158,8 @@ local function ann_scores_filter(task) end else lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (spam threshold)', - rule.prefix, set.name, set.ann.version, symscore, - spam_threshold) + rule.prefix, set.name, set.ann.version, symscore, + spam_threshold) end else local result = -(score) @@ -159,11 +180,12 @@ local function ann_scores_filter(task) end else lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (ham threshold)', - rule.prefix, set.name, set.ann.version, result, - ham_threshold) + rule.prefix, set.name, set.ann.version, result, + ham_threshold) end end end + ::continue_rule:: end end @@ -178,14 +200,14 @@ local function ann_push_task_result(rule, task, verdict, score, set) if not learn_spam then skip_reason = string.format('score < spam_score: %f < %f', - score, train_opts.spam_score) + score, train_opts.spam_score) end else learn_spam = verdict == 'spam' or verdict == 'junk' if not learn_spam then skip_reason = string.format('verdict: %s', - verdict) + verdict) end end @@ -193,14 +215,14 @@ local function ann_push_task_result(rule, task, verdict, score, set) learn_ham = score <= train_opts.ham_score if not learn_ham then skip_reason = string.format('score > ham_score: %f > %f', - score, train_opts.ham_score) + score, train_opts.ham_score) end else learn_ham = verdict == 'ham' if not learn_ham then skip_reason = string.format('verdict: %s', - verdict) + verdict) end end else @@ -221,7 +243,16 @@ local function ann_push_task_result(rule, task, verdict, score, set) learn_spam = false -- Explicitly store tokens in cache - local vec = neural_common.result_to_vector(task, set) + local vec + if rule.providers and #rule.providers > 0 then + local fused = neural_common.collect_features(task, rule, set, 'train') + if type(fused) == 'table' then + vec = fused + end + end + if not vec then + vec = neural_common.result_to_vector(task, set) + end task:cache_set(rule.prefix .. '_neural_vec_mpack', ucl.to_format(vec, 'msgpack')) task:cache_set(rule.prefix .. '_neural_profile_digest', set.digest) skip_reason = 'store_pool_only has been set' @@ -241,7 +272,16 @@ local function ann_push_task_result(rule, task, verdict, score, set) local nspam, nham = data[1], data[2] if neural_common.can_push_train_vector(rule, task, learn_type, nspam, nham) then - local vec = neural_common.result_to_vector(task, set) + local vec + if rule.providers and #rule.providers > 0 then + local fused = neural_common.collect_features(task, rule, set) + if type(fused) == 'table' then + vec = fused + end + end + if not vec then + vec = neural_common.result_to_vector(task, set) + end local str = rspamd_util.zstd_compress(table.concat(vec, ';')) local target_key = set.ann.redis_key .. '_' .. learn_type .. '_set' @@ -249,41 +289,41 @@ local function ann_push_task_result(rule, task, verdict, score, set) local function learn_vec_cb(redis_err) if redis_err then rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s', - rule.prefix, set.name, redis_err) + rule.prefix, set.name, redis_err) else lua_util.debugm(N, task, - "add train data for ANN rule " .. 
- "%s:%s, save %s vector of %s elts in %s key; %s bytes compressed", - rule.prefix, set.name, learn_type, #vec, target_key, #str) + "add train data for ANN rule " .. + "%s:%s, save %s vector of %s elts in %s key; %s bytes compressed", + rule.prefix, set.name, learn_type, #vec, target_key, #str) end end lua_redis.redis_make_request(task, - rule.redis, - nil, - true, -- is write - learn_vec_cb, --callback - 'SADD', -- command - { target_key, str } -- arguments + rule.redis, + nil, + true, -- is write + learn_vec_cb, --callback + 'SADD', -- command + { target_key, str } -- arguments ) else lua_util.debugm(N, task, - "do not add %s train data for ANN rule " .. - "%s:%s", - learn_type, rule.prefix, set.name) + "do not add %s train data for ANN rule " .. + "%s:%s", + learn_type, rule.prefix, set.name) end else if err then rspamd_logger.errx(task, 'cannot check if we can train %s:%s : %s', - rule.prefix, set.name, err) + rule.prefix, set.name, err) elseif type(data) == 'string' then -- nil return value rspamd_logger.infox(task, "cannot learn %s ANN %s:%s; redis_key: %s: locked for learning: %s", - learn_type, rule.prefix, set.name, set.ann.redis_key, data) + learn_type, rule.prefix, set.name, set.ann.redis_key, data) else rspamd_logger.errx(task, 'cannot check if we can train %s:%s : type of Redis key %s is %s, expected table' .. - 'please remove this key from Redis manually if you perform upgrade from the previous version', - rule.prefix, set.name, set.ann.redis_key, type(data)) + 'please remove this key from Redis manually if you perform upgrade from the previous version', + rule.prefix, set.name, set.ann.redis_key, type(data)) end end end @@ -294,25 +334,25 @@ local function ann_push_task_result(rule, task, verdict, score, set) -- Need to create or load a profile corresponding to the current configuration set.ann = new_ann_profile(task, rule, set, 0) lua_util.debugm(N, task, - 'requested new profile for %s, set.ann is missing', - set.name) + 'requested new profile for %s, set.ann is missing', + set.name) end lua_redis.exec_redis_script(neural_common.redis_script_id.vectors_len, - { task = task, is_write = false }, - vectors_len_cb, - { - set.ann.redis_key, - }) + { task = task, is_write = false }, + vectors_len_cb, + { + set.ann.redis_key, + }) else lua_util.debugm(N, task, - 'do not push data: train condition not satisfied; reason: not checked existing ANNs') + 'do not push data: train condition not satisfied; reason: not checked existing ANNs') end else lua_util.debugm(N, task, - 'do not push data to key %s: train condition not satisfied; reason: %s', - (set.ann or {}).redis_key, - skip_reason) + 'do not push data to key %s: train condition not satisfied; reason: %s', + (set.ann or {}).redis_key, + skip_reason) end end @@ -337,23 +377,29 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key) local function redis_ham_cb(err, data) if err or type(data) ~= 'table' then rspamd_logger.errx(rspamd_config, 'cannot get ham tokens for ANN %s from redis: %s', - ann_key, err) + ann_key, err) -- Unlock on error lua_redis.redis_make_request_taskless(ev_base, - rspamd_config, - rule.redis, - nil, - true, -- is write - neural_common.gen_unlock_cb(rule, set, ann_key), --callback - 'HDEL', -- command - { ann_key, 'lock' } + rspamd_config, + rule.redis, + nil, + true, -- is write + neural_common.gen_unlock_cb(rule, set, ann_key), --callback + 'HDEL', -- command + { ann_key, 'lock' } ) else -- Decompress and convert to numbers each training vector ham_elts = process_training_vectors(data) - 
neural_common.spawn_train({ worker = worker, ev_base = ev_base, - rule = rule, set = set, ann_key = ann_key, ham_vec = ham_elts, - spam_vec = spam_elts }) + neural_common.spawn_train({ + worker = worker, + ev_base = ev_base, + rule = rule, + set = set, + ann_key = ann_key, + ham_vec = ham_elts, + spam_vec = spam_elts + }) end end @@ -361,29 +407,29 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key) local function redis_spam_cb(err, data) if err or type(data) ~= 'table' then rspamd_logger.errx(rspamd_config, 'cannot get spam tokens for ANN %s from redis: %s', - ann_key, err) + ann_key, err) -- Unlock ANN on error lua_redis.redis_make_request_taskless(ev_base, - rspamd_config, - rule.redis, - nil, - true, -- is write - neural_common.gen_unlock_cb(rule, set, ann_key), --callback - 'HDEL', -- command - { ann_key, 'lock' } + rspamd_config, + rule.redis, + nil, + true, -- is write + neural_common.gen_unlock_cb(rule, set, ann_key), --callback + 'HDEL', -- command + { ann_key, 'lock' } ) else -- Decompress and convert to numbers each training vector spam_elts = process_training_vectors(data) -- Now get ham vectors... lua_redis.redis_make_request_taskless(ev_base, - rspamd_config, - rule.redis, - nil, - false, -- is write - redis_ham_cb, --callback - 'SMEMBERS', -- command - { ann_key .. '_ham_set' } + rspamd_config, + rule.redis, + nil, + false, -- is write + redis_ham_cb, --callback + 'SMEMBERS', -- command + { ann_key .. '_ham_set' } ) end end @@ -391,33 +437,33 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key) local function redis_lock_cb(err, data) if err then rspamd_logger.errx(rspamd_config, 'cannot call lock script for ANN %s from redis: %s', - ann_key, err) + ann_key, err) elseif type(data) == 'number' and data == 1 then -- ANN is locked, so we can extract SPAM and HAM vectors and spawn learning lua_redis.redis_make_request_taskless(ev_base, - rspamd_config, - rule.redis, - nil, - false, -- is write - redis_spam_cb, --callback - 'SMEMBERS', -- command - { ann_key .. '_spam_set' } + rspamd_config, + rule.redis, + nil, + false, -- is write + redis_spam_cb, --callback + 'SMEMBERS', -- command + { ann_key .. '_spam_set' } ) rspamd_logger.infox(rspamd_config, 'lock ANN %s:%s (key name %s) for learning', - rule.prefix, set.name, ann_key) + rule.prefix, set.name, ann_key) else local lock_tm = tonumber(data[1]) rspamd_logger.infox(rspamd_config, 'do not learn ANN %s:%s (key name %s), ' .. 
- 'locked by another host %s at %s', rule.prefix, set.name, ann_key, - data[2], os.date('%c', lock_tm)) + 'locked by another host %s at %s', rule.prefix, set.name, ann_key, + data[2], os.date('%c', lock_tm)) end end -- Check if we are already learning this network if set.learning_spawned then rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, already learning another ANN', - ann_key) + ann_key) return end @@ -425,14 +471,14 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key) -- This script returns either a boolean or a pair {'lock_time', 'hostname'} when -- ANN is locked by another host (or a process, meh) lua_redis.exec_redis_script(neural_common.redis_script_id.maybe_lock, - { ev_base = ev_base, is_write = true }, - redis_lock_cb, - { - ann_key, - tostring(os.time()), - tostring(math.max(10.0, rule.watch_interval * 2)), - rspamd_util.get_hostname() - }) + { ev_base = ev_base, is_write = true }, + redis_lock_cb, + { + ann_key, + tostring(os.time()), + tostring(math.max(10.0, rule.watch_interval * 2)), + rspamd_util.get_hostname() + }) end -- This function loads new ann from Redis @@ -447,7 +493,7 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff) local function data_cb(err, data) if err then rspamd_logger.errx(rspamd_config, 'cannot get ANN data from key: %s; %s', - ann_key, err) + ann_key, err) else if type(data) == 'table' then if type(data[1]) == 'userdata' and data[1].cookie == text_cookie then @@ -456,7 +502,7 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff) if _err or not ann_data then rspamd_logger.errx(rspamd_config, 'cannot decompress ANN for %s from Redis key %s: %s', - rule.prefix .. ':' .. set.name, ann_key, _err) + rule.prefix .. ':' .. set.name, ann_key, _err) return else ann = rspamd_kann.load(ann_data) @@ -467,7 +513,8 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff) version = profile.version, symbols = profile.symbols, distance = min_diff, - redis_key = profile.redis_key + redis_key = profile.redis_key, + providers_digest = profile.providers_digest, } local ucl = require "ucl" @@ -479,26 +526,26 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff) end -- Also update rank for the loaded ANN to avoid removal lua_redis.redis_make_request_taskless(ev_base, - rspamd_config, - rule.redis, - nil, - true, -- is write - rank_cb, --callback - 'ZADD', -- command - { set.prefix, tostring(rspamd_util.get_time()), profile_serialized } + rspamd_config, + rule.redis, + nil, + true, -- is write + rank_cb, --callback + 'ZADD', -- command + { set.prefix, tostring(rspamd_util.get_time()), profile_serialized } ) rspamd_logger.infox(rspamd_config, - 'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s', - rule.prefix, set.name, ann_key, #data[1], profile.version) + 'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s', + rule.prefix, set.name, ann_key, #data[1], profile.version) else rspamd_logger.errx(rspamd_config, - 'cannot unpack/deserialise ANN for %s:%s from Redis key %s', - rule.prefix, set.name, ann_key) + 'cannot unpack/deserialise ANN for %s:%s from Redis key %s', + rule.prefix, set.name, ann_key) end end else lua_util.debugm(N, rspamd_config, 'missing ANN for %s:%s in Redis key %s', - rule.prefix, set.name, ann_key) + rule.prefix, set.name, ann_key) end if set.ann and set.ann.ann and type(data[2]) == 'userdata' and data[2].cookie == text_cookie then @@ -510,8 +557,8 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff) local roc_thresholds = 
parser:get_object() set.ann.roc_thresholds = roc_thresholds rspamd_logger.infox(rspamd_config, - 'loaded ROC thresholds for %s:%s; version=%s', - rule.prefix, set.name, profile.version) + 'loaded ROC thresholds for %s:%s; version=%s', + rule.prefix, set.name, profile.version) rspamd_logger.debugx("ROC thresholds: %s", roc_thresholds) end end @@ -524,19 +571,19 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff) -- We can use PCA set.ann.pca = rspamd_tensor.load(pca_data) rspamd_logger.infox(rspamd_config, - 'loaded PCA for ANN for %s:%s from %s; %s bytes compressed; version=%s', - rule.prefix, set.name, ann_key, #data[3], profile.version) + 'loaded PCA for ANN for %s:%s from %s; %s bytes compressed; version=%s', + rule.prefix, set.name, ann_key, #data[3], profile.version) else -- no need in pca, why is it there? rspamd_logger.warnx(rspamd_config, - 'extra PCA for ANN for %s:%s from Redis key %s: no max inputs defined', - rule.prefix, set.name, ann_key) + 'extra PCA for ANN for %s:%s from Redis key %s: no max inputs defined', + rule.prefix, set.name, ann_key) end else -- pca can be missing merely if we have no max_inputs if rule.max_inputs then rspamd_logger.errx(rspamd_config, 'cannot unpack/deserialise ANN for %s:%s from Redis key %s: no PCA: %s', - rule.prefix, set.name, ann_key, _err) + rule.prefix, set.name, ann_key, _err) set.ann.ann = nil else -- It is okay @@ -545,21 +592,39 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff) end end + -- Providers meta (optional) + if set.ann and set.ann.ann and type(data[4]) == 'userdata' and data[4].cookie == text_cookie then + local ucl = require "ucl" + local parser = ucl.parser() + local ok = parser:parse_text(data[4]) + if ok then + set.ann.providers_meta = parser:get_object() + end + end + -- Normalization stats (optional) + if set.ann and set.ann.ann and type(data[5]) == 'userdata' and data[5].cookie == text_cookie then + local ucl = require "ucl" + local parser = ucl.parser() + local ok = parser:parse_text(data[5]) + if ok then + set.ann.norm_stats = parser:get_object() + end + end else lua_util.debugm(N, rspamd_config, 'no ANN key for %s:%s in Redis key %s', - rule.prefix, set.name, ann_key) + rule.prefix, set.name, ann_key) end end end lua_redis.redis_make_request_taskless(ev_base, - rspamd_config, - rule.redis, - nil, - false, -- is write - data_cb, --callback - 'HMGET', -- command - { ann_key, 'ann', 'roc_thresholds', 'pca' }, -- arguments - { opaque_data = true } + rspamd_config, + rule.redis, + nil, + false, -- is write + data_cb, --callback + 'HMGET', -- command + { ann_key, 'ann', 'roc_thresholds', 'pca', 'providers_meta', 'norm_stats' }, -- arguments + { opaque_data = true } ) end @@ -595,34 +660,34 @@ local function process_existing_ann(_, ev_base, rule, set, profiles) if set.ann.version < sel_elt.version then -- Load new ann rspamd_logger.infox(rspamd_config, 'ann %s is changed, ' .. - 'our version = %s, remote version = %s', - rule.prefix .. ':' .. set.name, - set.ann.version, - sel_elt.version) + 'our version = %s, remote version = %s', + rule.prefix .. ':' .. set.name, + set.ann.version, + sel_elt.version) load_new_ann(rule, ev_base, set, sel_elt, min_diff) else lua_util.debugm(N, rspamd_config, 'ann %s is not changed, ' .. - 'our version = %s, remote version = %s', - rule.prefix .. ':' .. set.name, - set.ann.version, - sel_elt.version) + 'our version = %s, remote version = %s', + rule.prefix .. ':' .. 
set.name, + set.ann.version, + sel_elt.version) end else -- We have some different ANN, so we need to compare distance if set.ann.distance > min_diff then -- Load more specific ANN rspamd_logger.infox(rspamd_config, 'more specific ann is available for %s, ' .. - 'our distance = %s, remote distance = %s', - rule.prefix .. ':' .. set.name, - set.ann.distance, - min_diff) + 'our distance = %s, remote distance = %s', + rule.prefix .. ':' .. set.name, + set.ann.distance, + min_diff) load_new_ann(rule, ev_base, set, sel_elt, min_diff) else lua_util.debugm(N, rspamd_config, 'ann %s is not changed or less specific, ' .. - 'our distance = %s, remote distance = %s', - rule.prefix .. ':' .. set.name, - set.ann.distance, - min_diff) + 'our distance = %s, remote distance = %s', + rule.prefix .. ':' .. set.name, + set.ann.distance, + min_diff) end end else @@ -660,14 +725,14 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles) local ann_key = sel_elt.redis_key lua_util.debugm(N, rspamd_config, "check if ANN %s needs to be trained", - ann_key) + ann_key) -- Create continuation closure local redis_len_cb_gen = function(cont_cb, what, is_final) return function(err, data) if err then rspamd_logger.errx(rspamd_config, - 'cannot get ANN %s trains %s from redis: %s', what, ann_key, err) + 'cannot get ANN %s trains %s from redis: %s', what, ann_key, err) elseif data and type(data) == 'number' or type(data) == 'string' then local ntrains = tonumber(data) or 0 lens[what] = ntrains @@ -688,67 +753,65 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles) end if max_len >= rule.train.max_trains and fun.all(len_bias_check_pred, lens) then lua_util.debugm(N, rspamd_config, - 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors', - ann_key, lens, rule.train.max_trains, what) + 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors', + ann_key, lens, rule.train.max_trains, what) cont_cb() else lua_util.debugm(N, rspamd_config, - 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)', - ann_key, what, lens, rule.train.max_trains) + 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)', + ann_key, what, lens, rule.train.max_trains) end else -- Probabilistic mode, just ensure that at least one vector is okay if min_len > 0 and max_len >= rule.train.max_trains then lua_util.debugm(N, rspamd_config, - 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors', - ann_key, lens, rule.train.max_trains, what) + 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors', + ann_key, lens, rule.train.max_trains, what) cont_cb() else lua_util.debugm(N, rspamd_config, - 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)', - ann_key, what, lens, rule.train.max_trains) + 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)', + ann_key, what, lens, rule.train.max_trains) end end - else lua_util.debugm(N, rspamd_config, - 'checked %s vectors in ANN %s: %s vectors; %s required, need to check other class vectors', - what, ann_key, ntrains, rule.train.max_trains) + 'checked %s vectors in ANN %s: %s vectors; %s required, need to check other class vectors', + what, ann_key, ntrains, rule.train.max_trains) cont_cb() end end end - end local function initiate_train() 
rspamd_logger.infox(rspamd_config, - 'need to learn ANN %s after %s required learn vectors', - ann_key, lens) + 'need to learn ANN %s after %s required learn vectors', + ann_key, lens) do_train_ann(worker, ev_base, rule, set, ann_key) end -- Spam vector is OK, check ham vector length local function check_ham_len() lua_redis.redis_make_request_taskless(ev_base, - rspamd_config, - rule.redis, - nil, - false, -- is write - redis_len_cb_gen(initiate_train, 'ham', true), --callback - 'SCARD', -- command - { ann_key .. '_ham_set' } + rspamd_config, + rule.redis, + nil, + false, -- is write + redis_len_cb_gen(initiate_train, 'ham', true), --callback + 'SCARD', -- command + { ann_key .. '_ham_set' } ) end lua_redis.redis_make_request_taskless(ev_base, - rspamd_config, - rule.redis, - nil, - false, -- is write - redis_len_cb_gen(check_ham_len, 'spam', false), --callback - 'SCARD', -- command - { ann_key .. '_spam_set' } + rspamd_config, + rule.redis, + nil, + false, -- is write + redis_len_cb_gen(check_ham_len, 'spam', false), --callback + 'SCARD', -- command + { ann_key .. '_spam_set' } ) end end @@ -761,7 +824,7 @@ local function load_ann_profile(element) local res, ucl_err = parser:parse_string(element) if not res then rspamd_logger.warnx(rspamd_config, 'cannot parse ANN from redis: %s', - ucl_err) + ucl_err) return nil else local profile = parser:get_object() @@ -781,11 +844,11 @@ local function check_anns(worker, cfg, ev_base, rule, process_callback, what) local function members_cb(err, data) if err then rspamd_logger.errx(cfg, 'cannot get ANNs list from redis: %s', - err) + err) set.can_store_vectors = true elseif type(data) == 'table' then lua_util.debugm(N, cfg, '%s: process element %s:%s', - what, rule.prefix, set.name) + what, rule.prefix, set.name) process_callback(worker, ev_base, rule, set, fun.map(load_ann_profile, data)) set.can_store_vectors = true end @@ -797,13 +860,13 @@ local function check_anns(worker, cfg, ev_base, rule, process_callback, what) -- Select the most appropriate to our profile but it should not differ by more -- than 30% of symbols lua_redis.redis_make_request_taskless(ev_base, - cfg, - rule.redis, - nil, - false, -- is write - members_cb, --callback - 'ZREVRANGE', -- command - { set.prefix, '0', tostring(settings.max_profiles) } -- arguments + cfg, + rule.redis, + nil, + false, -- is write + members_cb, --callback + 'ZREVRANGE', -- command + { set.prefix, '0', tostring(settings.max_profiles) } -- arguments ) end end -- Cycle over all settings @@ -817,23 +880,23 @@ local function cleanup_anns(rule, cfg, ev_base) local function invalidate_cb(err, data) if err then rspamd_logger.errx(cfg, 'cannot exec invalidate script in redis: %s', - err) + err) elseif type(data) == 'table' then for _, expired in ipairs(data) do local profile = load_ann_profile(expired) rspamd_logger.infox(cfg, 'invalidated ANN for %s; redis key: %s; version=%s', - rule.prefix .. ':' .. set.name, - profile.redis_key, - profile.version) + rule.prefix .. ':' .. 
set.name, + profile.redis_key, + profile.version) end end end if type(set) == 'table' then lua_redis.exec_redis_script(neural_common.redis_script_id.maybe_invalidate, - { ev_base = ev_base, is_write = true }, - invalidate_cb, - { set.prefix, tostring(settings.max_profiles) }) + { ev_base = ev_base, is_write = true }, + invalidate_cb, + { set.prefix, tostring(settings.max_profiles) }) end end end @@ -852,14 +915,14 @@ local function ann_push_vector(task) if verdict == 'passthrough' then lua_util.debugm(N, task, 'ignore task as its verdict is %s(%s)', - verdict, score) + verdict, score) return end if score ~= score then lua_util.debugm(N, task, 'ignore task as its score is nan (%s verdict)', - verdict) + verdict) return end @@ -872,7 +935,6 @@ local function ann_push_vector(task) else lua_util.debugm(N, task, 'settings not found in rule %s', rule.prefix) end - end end @@ -930,10 +992,23 @@ for k, r in pairs(rules) do if rule_elt.max_inputs and not has_blas then rspamd_logger.errx('cannot set max inputs to %s as BLAS is not compiled in', - rule_elt.name, rule_elt.max_inputs) + rule_elt.name, rule_elt.max_inputs) rule_elt.max_inputs = nil end + -- Phase 4: basic provider config validation + if rule_elt.providers and #rule_elt.providers > 0 then + for i, pcfg in ipairs(rule_elt.providers) do + if not (pcfg.type or pcfg.name) then + rspamd_logger.errx(rspamd_config, 'provider at index %s in rule %s has no type/name; will be ignored', i, k) + end + if (pcfg.type == 'llm' or pcfg.name == 'llm') and not (pcfg.model or (rspamd_config:get_all_opt('gpt') or {}).model) then + rspamd_logger.errx(rspamd_config, + 'llm provider in rule %s requires model; please set providers[i].model or gpt.model', k) + end + end + end + rspamd_logger.infox(rspamd_config, "register ann rule %s", k) settings.rules[k] = rule_elt rspamd_config:set_metric_symbol({ @@ -980,21 +1055,21 @@ for _, rule in pairs(settings.rules) do rspamd_config:add_on_load(function(cfg, ev_base, worker) if worker:is_scanner() then rspamd_config:add_periodic(ev_base, 0.0, - function(_, _) - return check_anns(worker, cfg, ev_base, rule, process_existing_ann, - 'try_load_ann') - end) + function(_, _) + return check_anns(worker, cfg, ev_base, rule, process_existing_ann, + 'try_load_ann') + end) end if worker:is_primary_controller() then -- We also want to train neural nets when they have enough data rspamd_config:add_periodic(ev_base, 0.0, - function(_, _) - -- Clean old ANNs - cleanup_anns(rule, cfg, ev_base) - return check_anns(worker, cfg, ev_base, rule, maybe_train_existing_ann, - 'try_train_ann') - end) + function(_, _) + -- Clean old ANNs + cleanup_anns(rule, cfg, ev_base) + return check_anns(worker, cfg, ev_base, rule, maybe_train_existing_ann, + 'try_train_ann') + end) end end) end
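
Notes on the new metadata (illustrative sketches, not the committed implementation):

The `norm_stats` blob fetched by the extended `HMGET` above feeds inference-time normalization. A minimal sketch of applying z-score stats to a fused feature vector, assuming `norm_stats` carries per-dimension `mean`/`std` arrays — the actual layout is defined by the training pipeline in `lualib/plugins/neural.lua`:

```lua
-- Hypothetical sketch: apply trained z-score stats to a fused vector before
-- it reaches the ANN. The `mean`/`std` field names are assumptions.
local function apply_norm_stats(vec, norm_stats)
  if not norm_stats or not norm_stats.mean then
    return vec -- fusion.normalization = "none", or stats not trained yet
  end
  local out = {}
  for i = 1, #vec do
    local mean = norm_stats.mean[i] or 0.0
    local std = norm_stats.std[i] or 1.0
    if std < 1e-6 then
      std = 1.0 -- guard against degenerate (constant) dimensions
    end
    out[i] = (vec[i] - mean) / std
  end
  return out
end
```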
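
The `providers_digest` propagated into `set.ann` above guards against applying an ANN to vectors produced by a different provider set. A rough illustration of the idea, assuming the digest is a stable hash over the ordered provider configuration (the exact fields hashed are an implementation detail of `lualib/plugins/neural.lua`):

```lua
local rspamd_cryptobox_hash = require "rspamd_cryptobox_hash"

-- Hypothetical digest over ordered provider entries; the chosen fields
-- (type/name, model, weight) are assumptions.
local function providers_digest(providers)
  local h = rspamd_cryptobox_hash.create()
  for _, p in ipairs(providers or {}) do
    h:update(string.format('%s:%s:%s;',
        p.type or p.name or 'symbols',
        p.model or '-',
        tostring(p.weight or 1.0)))
  end
  return h:hex()
end

-- At load time, a stored profile whose digest differs from the digest of the
-- currently configured providers must not be applied, e.g.:
--   if profile.providers_digest ~= providers_digest(rule.providers) then
--     return -- skip apply; provider set has changed
--   end
```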
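
For reference, a providers table that passes the Phase 4 validation shown in the diff: every entry carries a `type` (or `name`), and an `llm` entry supplies a `model` (alternatively, `gpt.model` provides the default). Values below are illustrative only:

```lua
-- Passes validation: both entries are typed, and the llm entry has a model.
rule_elt.providers = {
  { type = 'symbols', weight = 0.5 },
  { type = 'llm', model = 'example-embedding-model', weight = 1.0 },
}
```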