This PR evolves the neural module from a symbols-only scorer into a general feature-fusion classifier with pluggable providers. It adds an LLM embedding provider, introduces trained normalization and metadata persistence, and isolates new models via a schema/prefix bump.
- The existing neural module is limited to metatokens and symbols.
- We want to combine multiple feature sources (LLM embeddings now; Bayes/FastText later).
- Ensure consistent train/infer behavior with stored normalization and provider metadata.
- Improve operability with caching, digest checks, and safer rollouts.
- Provider architecture
- Provider registry and fusion: `collect_features(task, rule, profile_or_set, phase)` concatenates provider vectors, scaling each by its optional weight (a registration sketch follows the file list below).
- New LLM provider: `lualib/plugins/neural/providers/llm.lua` using `rspamd_http` and `lua_cache` for Redis-backed embedding caching.
- Symbols provider extracted: `lualib/plugins/neural/providers/symbols.lua`.
- Normalization and PCA
- Configurable fusion normalization: none/unit/zscore.
- Trained normalization stats are computed during training and applied at inference (a worked sketch appears after the rollout notes below).
- Existing global PCA preserved; loaded/saved alongside ANN.
- Schema and compatibility
- `plugin_ver` bumped to '3' to isolate from earlier profiles.
- Redis save/load extended:
- Profiles include `providers_digest`.
- ANN hash can include `providers_meta`, `norm_stats`, `pca`, `roc_thresholds`, `ann`.
- ANN load validates provider digest and skips apply on mismatch.
- Performance and reliability
- LLM embeddings cached in Redis (content+model keyed).
- Graceful fallback to the symbols provider when other providers are not configured or fail.
- Basic provider configuration validation.
- `lualib/plugins/neural.lua`: provider registry, fusion, normalization helpers, profile digests, training pipeline updates.
- `src/plugins/lua/neural.lua`: integrates fusion into inference/learning, loads new metadata, applies normalization, validates digest.
- `lualib/plugins/neural/providers/llm.lua`: LLM embeddings with Redis cache.
- `lualib/plugins/neural/providers/symbols.lua`: legacy symbols provider wrapper.
- `lualib/redis_scripts/neural_save_unlock.lua`: stores `providers_meta` and `norm_stats` in ANN hash.
- `NEURAL_REWORK_PLAN.md`: design and phased TODO.
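For extension authors, the registry is the integration point referenced above. A minimal sketch of a custom provider, assuming the hypothetical name `my_provider` and a placeholder one-element feature vector (neither is part of this PR):

```lua
-- Hedged sketch: registering a custom provider against the new registry.
-- register_provider() and the collect(task, ctx) -> vec, meta contract are
-- from this PR; the feature logic itself is a placeholder.
local neural_common = require "plugins/neural"

neural_common.register_provider('my_provider', {
  collect = function(task, ctx)
    -- Return a numeric vector plus a meta table describing it
    local vec = { task:get_size() > 0 and 1.0 or 0.0 } -- placeholder feature
    return vec, {
      name = 'my_provider',
      type = 'my_provider',
      dim = #vec,
      weight = ctx.weight or 1.0,
    }
  end
})
```

A provider file like this still has to be required at load time (the bundled providers are pulled in via `pcall(require, ...)` in `src/plugins/lua/neural.lua`) and can then be referenced from a rule as `{ type = "my_provider"; weight = 0.3; }`.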
- Example: enable the LLM provider alongside symbols:
```ucl
neural {
rules {
default {
providers = [
{ type = "symbols"; weight = 0.5; },
{ type = "llm"; model = "text-embed-1"; url = "https://api.openai.com/v1/embeddings";
cache_ttl = 86400; weight = 1.0; }
];
fusion { normalization = "zscore"; }
roc_enabled = true;
max_inputs = 256; # optional PCA
}
}
}
```
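Given the weights above, fusion is plain concatenation with per-provider scaling. A standalone sketch of just that arithmetic (plain Lua, mirroring the loop in `collect_features`; not a drop-in):

```lua
-- Each provider vector is scaled by its weight, then concatenated.
local function fuse(vectors, weights)
  local fused = {}
  for i, vec in ipairs(vectors) do
    local w = weights[i] or 1.0
    for _, x in ipairs(vec) do
      fused[#fused + 1] = x * w
    end
  end
  return fused
end

-- e.g. a 3-dim symbols vector (weight 0.5) plus a 4-dim embedding (weight 1.0)
local fused = fuse({ { 1, 0, 1 }, { 0.1, 0.2, 0.3, 0.4 } }, { 0.5, 1.0 })
assert(#fused == 7) -- total_dim is the sum of provider dims
```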
- The LLM provider takes defaults from the global `gpt` block when present (e.g., API key); `model`, `url`, `timeout`, and cache parameters can be overridden per provider entry.
- Existing (v2) neural profiles remain unaffected, since the new `plugin_ver = '3'` uses separate key prefixes.
- New profiles embed `providers_digest`; incompatible provider sets won’t be applied.
- No immediate cleanup is required; old keys simply expire via their TTLs.
- Validated: provider digest checks, ANN load/save roundtrip, normalization application at inference, LLM caching paths, symbols fallback.
- Please test with and without the LLM provider, and with each of `fusion.normalization = none|unit|zscore`.
- LLM latency/cost is mitigated by Redis caching; timeouts are configurable per provider.
- Privacy: use trusted endpoints; message content is only sent out when an LLM provider is explicitly configured.
- Failure behavior: missing/failed providers degrade to others; training/inference can proceed with partial features.
- Rules without `providers` continue to use symbols-only behavior.
- Existing command surface unchanged; future PR will introduce `rspamc learn_neural:*` and controller endpoints.
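To make the train/infer normalization contract concrete: for `zscore`, per-dimension mean/std are computed once at training time, stored as `norm_stats` in the ANN hash, and replayed verbatim at inference. A self-contained sketch of the application step (mirrors `apply_normalization`; illustrative only):

```lua
-- The same mean/std computed at training time must be used at inference,
-- so train and scan see identically scaled inputs.
local function zscore_apply(vec, stats)
  for i = 1, math.min(#vec, #stats.mean) do
    vec[i] = (vec[i] - (stats.mean[i] or 0.0)) / (stats.std[i] or 1.0)
  end
  return vec
end

local stats = { mean = { 0.5, 2.0 }, std = { 0.5, 1.0 } } -- from training
local v = zscore_apply({ 1.0, 3.0 }, stats)
print(v[1], v[2]) -- 1, 1: both dimensions come out one std above their mean
```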
- [x] Provider registry and fusion
- [x] LLM provider with Redis caching
- [x] Symbols provider split
- [x] Normalization (unit/zscore) with trained stats
- [x] Redis schema v3 additions and profile digest
- [x] Inference uses trained normalization
- [x] Basic provider validation and fallbacks
- [x] Plan document
- [ ] Per-provider budgets/metrics and circuit breaker for LLM
- [ ] Expand providers: Bayes and FastText/subword vectors
- [ ] Per-provider PCA and learned fusion
- [ ] New CLI (`rspamc learn_neural`) and status/invalidate endpoints
- [ ] Documentation expansion under `docs/modules/neural.md`
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-]]--
+]] --
local fun = require "fun"
local lua_redis = require "lua_redis"
local N = 'neural'
-- Used in prefix to avoid wrong ANN to be loaded
-local plugin_ver = '2'
+local plugin_ver = '3'
-- Module vars
local default_options = {
learn_threads = 1,
learn_mode = 'balanced', -- Possible values: balanced, proportional
learning_rate = 0.01,
- classes_bias = 0.0, -- balanced mode: what difference is allowed between classes (1:1 proportion means 0 bias)
- spam_skip_prob = 0.0, -- proportional mode: spam skip probability (0-1)
- ham_skip_prob = 0.0, -- proportional mode: ham skip probability
+ classes_bias = 0.0, -- balanced mode: what difference is allowed between classes (1:1 proportion means 0 bias)
+ spam_skip_prob = 0.0, -- proportional mode: spam skip probability (0-1)
+ ham_skip_prob = 0.0, -- proportional mode: ham skip probability
store_pool_only = false, -- store tokens in cache only (disables autotrain);
-- neural_vec_mpack stores vector of training data in messagepack neural_profile_digest stores profile digest
},
watch_interval = 60.0,
lock_expire = 600,
learning_spawned = false,
- ann_expire = 60 * 60 * 24 * 2, -- 2 days
- hidden_layer_mult = 1.5, -- number of neurons in the hidden layer
- roc_enabled = false, -- Use ROC to find the best possible thresholds for ham and spam. If spam_score_threshold or ham_score_threshold is defined, it takes precedence over ROC thresholds.
+ ann_expire = 60 * 60 * 24 * 2, -- 2 days
+ hidden_layer_mult = 1.5, -- number of neurons in the hidden layer
+ roc_enabled = false, -- Use ROC to find the best possible thresholds for ham and spam. If spam_score_threshold or ham_score_threshold is defined, it takes precedence over ROC thresholds.
roc_misclassification_cost = 0.5, -- Cost of misclassifying a spam message (must be 0..1).
- spam_score_threshold = nil, -- neural score threshold for spam (must be 0..1 or nil to disable)
- ham_score_threshold = nil, -- neural score threshold for ham (must be 0..1 or nil to disable)
- flat_threshold_curve = false, -- use binary classification 0/1 when threshold is reached
+ spam_score_threshold = nil, -- neural score threshold for spam (must be 0..1 or nil to disable)
+ ham_score_threshold = nil, -- neural score threshold for ham (must be 0..1 or nil to disable)
+ flat_threshold_curve = false, -- use binary classification 0/1 when threshold is reached
symbol_spam = 'NEURAL_SPAM',
symbol_ham = 'NEURAL_HAM',
- max_inputs = nil, -- when PCA is used
- blacklisted_symbols = {}, -- list of symbols skipped in neural processing
+ max_inputs = nil, -- when PCA is used
+ blacklisted_symbols = {}, -- list of symbols skipped in neural processing
+ -- Phase 0 additions (scaffolding for feature providers)
+ providers = nil, -- list of provider configs; if nil, fallback to symbols-only provider
+ fusion = {
+ normalization = 'none', -- none|unit|zscore (zscore requires stats)
+ per_provider_pca = false, -- if true, apply PCA per provider before fusion (not active yet)
+ },
+ disable_symbols_input = false, -- when true, do not use symbols provider unless explicitly listed
}
-- Rule structure:
local settings = {
rules = {},
- prefix = 'rn', -- Neural network default prefix
+ prefix = 'rn', -- Neural network default prefix
max_profiles = 3, -- Maximum number of NN profiles stored
}
local redis_script_id = {}
+-- Provider registry (Phase 0 scaffolding)
+local registered_providers = {}
+
+--- Registers a feature provider implementation
+-- @param name string
+-- @param provider table with function collect(task, ctx) -> vector(table of numbers), meta(table)
+local function register_provider(name, provider)
+ registered_providers[name] = provider
+end
+
+local function get_provider(name)
+ return registered_providers[name]
+end
+
+-- Forward declaration
+local result_to_vector
+
+-- Built-in symbols provider (compatibility path)
+register_provider('symbols', {
+ collect = function(task, ctx)
+ -- ctx.profile is expected for symbols provider
+ local vec = result_to_vector(task, ctx.profile)
+ return vec, { name = 'symbols', type = 'symbols', dim = #vec, weight = ctx.weight or 1.0 }
+ end
+})
+
local function load_scripts()
redis_script_id.vectors_len = lua_redis.load_redis_script_from_file(redis_lua_script_vectors_len,
- redis_params)
+ redis_params)
redis_script_id.maybe_invalidate = lua_redis.load_redis_script_from_file(redis_lua_script_maybe_invalidate,
- redis_params)
+ redis_params)
redis_script_id.maybe_lock = lua_redis.load_redis_script_from_file(redis_lua_script_maybe_lock,
- redis_params)
+ redis_params)
redis_script_id.save_unlock = lua_redis.load_redis_script_from_file(redis_lua_script_save_unlock,
- redis_params)
+ redis_params)
end
local function create_ann(n, nlayers, rule)
return w
end
+-- Build providers metadata for storage alongside ANN
+local function build_providers_meta(metas)
+ if not metas or #metas == 0 then return nil end
+ local out = {}
+ for i, m in ipairs(metas) do
+ out[i] = {
+ name = m.name,
+ type = m.type,
+ dim = m.dim,
+ weight = m.weight,
+ model = m.model,
+ provider = m.provider,
+ }
+ end
+ return out
+end
+
+-- Normalization helpers
+local function l2_normalize_vector(vec)
+ local sumsq = 0.0
+ for i = 1, #vec do
+ local v = vec[i]
+ sumsq = sumsq + v * v
+ end
+ if sumsq > 0 then
+ local inv = 1.0 / math.sqrt(sumsq)
+ for i = 1, #vec do
+ vec[i] = vec[i] * inv
+ end
+ end
+ return vec
+end
+
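+-- One-pass per-dimension mean/std using Welford's online algorithm
+-- (numerically stable; sample std uses the n-1 denominator)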
+local function compute_zscore_stats(inputs)
+ local n = #inputs
+ if n == 0 then return nil end
+ local d = #inputs[1]
+ local mean = {}
+ local m2 = {}
+ for j = 1, d do
+ mean[j] = 0.0
+ m2[j] = 0.0
+ end
+ for i = 1, n do
+ local x = inputs[i]
+ for j = 1, d do
+ local delta = x[j] - mean[j]
+ mean[j] = mean[j] + delta / i
+ m2[j] = m2[j] + delta * (x[j] - mean[j])
+ end
+ end
+ local std = {}
+ for j = 1, d do
+ std[j] = math.sqrt((n > 1 and (m2[j] / (n - 1))) or 0.0)
+ if std[j] == 0 or std[j] ~= std[j] then
+ std[j] = 1.0 -- avoid division by zero and NaN
+ end
+ end
+ return { mode = 'zscore', mean = mean, std = std }
+end
+
+local function apply_normalization(vec, norm_stats_or_mode)
+ if not norm_stats_or_mode then return vec end
+ if type(norm_stats_or_mode) == 'string' then
+ if norm_stats_or_mode == 'unit' then
+ return l2_normalize_vector(vec)
+ else
+ return vec
+ end
+ else
+ if norm_stats_or_mode.mode == 'unit' then
+ return l2_normalize_vector(vec)
+ elseif norm_stats_or_mode.mode == 'zscore' and norm_stats_or_mode.mean and norm_stats_or_mode.std then
+ local mean = norm_stats_or_mode.mean
+ local std = norm_stats_or_mode.std
+ for i = 1, math.min(#vec, #mean) do
+ vec[i] = (vec[i] - (mean[i] or 0.0)) / (std[i] or 1.0)
+ end
+ return vec
+ else
+ return vec
+ end
+ end
+end
+
-- This function computes optimal threshold using ROC for the given set of inputs.
-- Returns a threshold that minimizes:
-- alpha * (false_positive_rate) + beta * (false_negative_rate)
-- Where alpha is cost of false positive result
-- beta is cost of false negative result
local function get_roc_thresholds(ann, inputs, outputs, alpha, beta)
-
-- Sorts list x and list y based on the values in list x.
local sort_relative = function(x, y)
-
local r = {}
assert(#x == #y)
spam_count_ahead[n_samples + 1] = 0
for i = n_samples, 1, -1 do
-
if outputs[i][1] == 0 then
n_ham = n_ham + 1
ham_count_ahead[i] = 1
-- `set.learning_spawned` is set to `true`
local function register_lock_extender(rule, set, ev_base, ann_key)
rspamd_config:add_periodic(ev_base, 30.0,
- function()
- local function redis_lock_extend_cb(err, _)
- if err then
- rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s',
- ann_key, err)
- else
- rspamd_logger.infox(rspamd_config, 'extend lock for ANN %s for 30 seconds',
- ann_key)
- end
- end
-
- if set.learning_spawned then
- lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- true, -- is write
- redis_lock_extend_cb, --callback
- 'HINCRBY', -- command
- { ann_key, 'lock', '30' }
- )
+ function()
+ local function redis_lock_extend_cb(err, _)
+ if err then
+ rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s',
+ ann_key, err)
else
- lua_util.debugm(N, rspamd_config, "stop lock extension as learning_spawned is false")
- return false -- do not plan any more updates
+ rspamd_logger.infox(rspamd_config, 'extend lock for ANN %s for 30 seconds',
+ ann_key)
end
+ end
- return true
+ if set.learning_spawned then
+ lua_redis.redis_make_request_taskless(ev_base,
+ rspamd_config,
+ rule.redis,
+ nil,
+ true, -- is write
+ redis_lock_extend_cb, --callback
+ 'HINCRBY', -- command
+ { ann_key, 'lock', '30' }
+ )
+ else
+ lua_util.debugm(N, rspamd_config, "stop lock extension as learning_spawned is false")
+ return false -- do not plan any more updates
end
+
+ return true
+ end
)
end
local skip_rate = 1.0 - nham / (nspam + 1)
if coin < skip_rate - train_opts.classes_bias then
rspamd_logger.infox(task,
- 'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored',
- learn_type,
- skip_rate - train_opts.classes_bias,
- nspam, nham)
+ 'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored',
+ learn_type,
+ skip_rate - train_opts.classes_bias,
+ nspam, nham)
return false
end
end
else
-- Enough learns
rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many spam samples: %s',
- learn_type,
- nspam)
+ learn_type,
+ nspam)
end
else
if nham <= train_opts.max_trains then
local skip_rate = 1.0 - nspam / (nham + 1)
if coin < skip_rate - train_opts.classes_bias then
rspamd_logger.infox(task,
- 'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored',
- learn_type,
- skip_rate - train_opts.classes_bias,
- nspam, nham)
+ 'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored',
+ learn_type,
+ skip_rate - train_opts.classes_bias,
+ nspam, nham)
return false
end
end
return true
else
rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many ham samples: %s', learn_type,
- nham)
+ nham)
end
end
else
if train_opts.spam_skip_prob then
if coin <= train_opts.spam_skip_prob then
rspamd_logger.infox(task, 'skip %s sample probabilistically; probability %s (%s skip chance)', learn_type,
- coin, train_opts.spam_skip_prob)
+ coin, train_opts.spam_skip_prob)
return false
end
end
else
rspamd_logger.infox(task, 'skip %s sample; too many spam samples: %s (%s limit)', learn_type,
- nspam, train_opts.max_trains)
+ nspam, train_opts.max_trains)
end
else
if nham <= train_opts.max_trains then
if train_opts.ham_skip_prob then
if coin <= train_opts.ham_skip_prob then
rspamd_logger.infox(task, 'skip %s sample probabilistically; probability %s (%s skip chance)', learn_type,
- coin, train_opts.ham_skip_prob)
+ coin, train_opts.ham_skip_prob)
return false
end
end
else
rspamd_logger.infox(task, 'skip %s sample; too many ham samples: %s (%s limit)', learn_type,
- nham, train_opts.max_trains)
+ nham, train_opts.max_trains)
end
end
end
return function(err)
if err then
rspamd_logger.errx(rspamd_config, 'cannot unlock ANN %s:%s at %s from redis: %s',
- rule.prefix, set.name, ann_key, err)
+ rule.prefix, set.name, ann_key, err)
else
lua_util.debugm(N, rspamd_config, 'unlocked ANN %s:%s at %s',
- rule.prefix, set.name, ann_key)
+ rule.prefix, set.name, ann_key)
end
end
end
-- Used to generate new ANN key for specific profile
local function new_ann_key(rule, set, version)
local ann_key = string.format('%s_%s_%s_%s_%s', settings.prefix,
- rule.prefix, set.name, set.digest:sub(1, 8), tostring(version))
+ rule.prefix, set.name, set.digest:sub(1, 8), tostring(version))
return ann_key
end
-- We also need to count metatokens:
local n = meta_functions.version
return string.format('%s%d_%s_%d_%s',
- settings.prefix, plugin_ver, rule.prefix, n, settings_name)
+ settings.prefix, plugin_ver, rule.prefix, n, settings_name)
+end
+
+-- Compute a stable digest for providers configuration
+local function providers_config_digest(providers_cfg)
+ if not providers_cfg then return nil end
+ -- Normalize minimal subset of fields to keep digest stable across equivalent configs
+ local norm = {}
+ for i, p in ipairs(providers_cfg) do
+ norm[i] = {
+ type = p.type,
+ name = p.name,
+ weight = p.weight or 1.0,
+ dim = p.dim,
+ }
+ end
+ return lua_util.table_digest(norm)
+end
+
+-- If no providers configured, fallback to symbols provider unless disabled
+-- phase: 'infer' | 'train'
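+-- Returns: fused vector (or nil if nothing was collected) plus a meta table
+-- { providers, total_dim, digest = providers_config_digest(...) }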
+local function collect_features(task, rule, profile_or_set, phase)
+ local vectors = {}
+ local metas = {}
+
+ local providers_cfg = rule.providers
+ if not providers_cfg or #providers_cfg == 0 then
+ if not rule.disable_symbols_input then
+ local prov = get_provider('symbols')
+ if prov then
+ local vec, meta = prov.collect(task, { profile = profile_or_set, weight = 1.0 })
+ if vec then
+ vectors[#vectors + 1] = vec
+ metas[#metas + 1] = meta
+ end
+ end
+ end
+ else
+ for _, pcfg in ipairs(providers_cfg) do
+ local prov = get_provider(pcfg.type or pcfg.name)
+ if prov then
+ local ok, vec, meta = pcall(function()
+ return prov.collect(task, {
+ profile = profile_or_set,
+ rule = rule,
+ config = pcfg,
+ weight = pcfg.weight or 1.0,
+ phase = phase,
+ })
+ end)
+ if ok and vec then
+ if meta then
+ meta.weight = pcfg.weight or meta.weight or 1.0
+ end
+ vectors[#vectors + 1] = vec
+ metas[#metas + 1] = meta or
+ { name = pcfg.name or pcfg.type, type = pcfg.type, dim = #vec, weight = pcfg.weight or 1.0 }
+ else
+        rspamd_logger.debugm(N, rspamd_config, 'provider %s failed to collect features: %s',
+          pcfg.type or pcfg.name, not ok and vec or 'no vector returned')
+ end
+ else
+ rspamd_logger.debugm(N, rspamd_config, 'provider %s is not registered', pcfg.type or pcfg.name)
+ end
+ end
+ end
+
+ -- Simple fusion by concatenation; optional per-provider weight scaling
+ local fused = {}
+ for i, v in ipairs(vectors) do
+ local w = (metas[i] and metas[i].weight) or 1.0
+ -- Apply normalization if requested
+ local norm_mode = (rule.fusion and rule.fusion.normalization) or 'none'
+ if norm_mode ~= 'none' then
+ v = apply_normalization(v, norm_mode)
+ end
+ for _, x in ipairs(v) do
+ fused[#fused + 1] = x * w
+ end
+ end
+
+ local meta = {
+ providers = build_providers_meta(metas) or metas,
+ total_dim = #fused,
+ digest = providers_config_digest(providers_cfg),
+ }
+
+ if #fused == 0 then
+ return nil, meta
+ end
+
+ return fused, meta
end
-- This function receives training vectors, checks them, spawn learning and saves ANN in Redis
-- We have nan :( try to log lot's of stuff to dig into a problem
seen_nan = true
rspamd_logger.errx(rspamd_config, 'ANN %s:%s: train error: observed nan in error cost!; value cost = %s',
- params.rule.prefix, params.set.name,
- value_cost)
+ params.rule.prefix, params.set.name,
+ value_cost)
for i, e in ipairs(inputs) do
lua_util.debugm(N, rspamd_config, 'train vector %s -> %s',
- debug_vec(e), outputs[i][1])
+ debug_vec(e), outputs[i][1])
end
end
rspamd_logger.infox(rspamd_config,
- "ANN %s:%s: learned from %s redis key in %s iterations, error: %s, value cost: %s",
- params.rule.prefix, params.set.name,
- params.ann_key,
- iter,
- train_cost,
- value_cost)
+ "ANN %s:%s: learned from %s redis key in %s iterations, error: %s, value cost: %s",
+ params.rule.prefix, params.set.name,
+ params.ann_key,
+ iter,
+ train_cost,
+ value_cost)
end
end
lua_util.debugm(N, rspamd_config, "subprocess to learn ANN %s:%s has been started",
- params.rule.prefix, params.set.name)
+ params.rule.prefix, params.set.name)
local pca
if params.rule.max_inputs then
-- Train PCA in the main process, presumably it is not that long
lua_util.debugm(N, rspamd_config, "start PCA train for ANN %s:%s",
- params.rule.prefix, params.set.name)
+ params.rule.prefix, params.set.name)
pca = learn_pca(inputs, params.rule.max_inputs)
end
+ -- Compute normalization stats if requested
+ local norm_stats
+ if params.rule.fusion and params.rule.fusion.normalization == 'zscore' then
+ norm_stats = compute_zscore_stats(inputs)
+ elseif params.rule.fusion and params.rule.fusion.normalization == 'unit' then
+ norm_stats = { mode = 'unit' }
+ end
+
+ if norm_stats then
+ for i = 1, #inputs do
+ inputs[i] = apply_normalization(inputs[i], norm_stats)
+ end
+ end
+
lua_util.debugm(N, rspamd_config, "start neural train for ANN %s:%s",
- params.rule.prefix, params.set.name)
+ params.rule.prefix, params.set.name)
local ret, err = pcall(train_ann.train1, train_ann,
- inputs, outputs, {
- lr = params.rule.train.learning_rate,
- max_epoch = params.rule.train.max_iterations,
- cb = train_cb,
- pca = pca
- })
+ inputs, outputs, {
+ lr = params.rule.train.learning_rate,
+ max_epoch = params.rule.train.max_iterations,
+ cb = train_cb,
+ pca = pca
+ })
if not ret then
rspamd_logger.errx(rspamd_config, "cannot train ann %s:%s: %s",
- params.rule.prefix, params.set.name, err)
+ params.rule.prefix, params.set.name, err)
return nil
else
lua_util.debugm(N, rspamd_config, "finished neural train for ANN %s:%s",
- params.rule.prefix, params.set.name)
+ params.rule.prefix, params.set.name)
end
local roc_thresholds = {}
if params.rule.roc_enabled then
local spam_threshold = get_roc_thresholds(train_ann,
- inputs,
- outputs,
- 1 - params.rule.roc_misclassification_cost,
- params.rule.roc_misclassification_cost)
+ inputs,
+ outputs,
+ 1 - params.rule.roc_misclassification_cost,
+ params.rule.roc_misclassification_cost)
local ham_threshold = get_roc_thresholds(train_ann,
- inputs,
- outputs,
- params.rule.roc_misclassification_cost,
- 1 - params.rule.roc_misclassification_cost)
+ inputs,
+ outputs,
+ params.rule.roc_misclassification_cost,
+ 1 - params.rule.roc_misclassification_cost)
roc_thresholds = { spam_threshold, ham_threshold }
rspamd_logger.messagex("ROC thresholds: (spam_threshold: %s, ham_threshold: %s)",
- roc_thresholds[1], roc_thresholds[2])
+ roc_thresholds[1], roc_thresholds[2])
end
if not seen_nan then
ann_data = tostring(train_ann:save()),
pca_data = pca_data,
roc_thresholds = roc_thresholds,
+ norm_stats = norm_stats,
}
local final_data = ucl.to_format(out, 'msgpack')
lua_util.debugm(N, rspamd_config, "subprocess for ANN %s:%s returned %s bytes",
- params.rule.prefix, params.set.name, #final_data)
+ params.rule.prefix, params.set.name, #final_data)
return final_data
else
return nil
local function redis_save_cb(err)
if err then
rspamd_logger.errx(rspamd_config, 'cannot save ANN %s:%s to redis key %s: %s',
- params.rule.prefix, params.set.name, params.ann_key, err)
+ params.rule.prefix, params.set.name, params.ann_key, err)
lua_redis.redis_make_request_taskless(params.ev_base,
- rspamd_config,
- params.rule.redis,
- nil,
- false, -- is write
- gen_unlock_cb(params.rule, params.set, params.ann_key), --callback
- 'HDEL', -- command
- { params.ann_key, 'lock' }
+ rspamd_config,
+ params.rule.redis,
+ nil,
+ false, -- is write
+ gen_unlock_cb(params.rule, params.set, params.ann_key), --callback
+ 'HDEL', -- command
+ { params.ann_key, 'lock' }
)
else
rspamd_logger.infox(rspamd_config, 'saved ANN %s:%s to redis: %s',
- params.rule.prefix, params.set.name, params.set.ann.redis_key)
+ params.rule.prefix, params.set.name, params.set.ann.redis_key)
end
end
params.set.learning_spawned = false
if err then
rspamd_logger.errx(rspamd_config, 'cannot train ANN %s:%s : %s',
- params.rule.prefix, params.set.name, err)
+ params.rule.prefix, params.set.name, err)
lua_redis.redis_make_request_taskless(params.ev_base,
- rspamd_config,
- params.rule.redis,
- nil,
- true, -- is write
- gen_unlock_cb(params.rule, params.set, params.ann_key), --callback
- 'HDEL', -- command
- { params.ann_key, 'lock' }
+ rspamd_config,
+ params.rule.redis,
+ nil,
+ true, -- is write
+ gen_unlock_cb(params.rule, params.set, params.ann_key), --callback
+ 'HDEL', -- command
+ { params.ann_key, 'lock' }
)
else
local parser = ucl.parser()
local ann_data = rspamd_util.zstd_compress(parsed.ann_data)
local pca_data = parsed.pca_data
local roc_thresholds = parsed.roc_thresholds
+ local norm_stats = parsed.norm_stats
fill_set_ann(params.set, params.ann_key)
if pca_data then
symbols = params.set.symbols,
digest = params.set.digest,
redis_key = params.set.ann.redis_key,
- version = version
+ version = version,
+ providers_digest = providers_config_digest(params.rule.providers),
}
local profile_serialized = ucl.to_format(profile, 'json-compact', true)
local roc_thresholds_serialized = ucl.to_format(roc_thresholds, 'json-compact', true)
+ local providers_meta_serialized
+ if params.rule.providers then
+ providers_meta_serialized = ucl.to_format(
+ build_providers_meta(params.set.ann.providers or params.rule.providers), 'json-compact', true)
+ end
rspamd_logger.infox(rspamd_config,
- 'trained ANN %s:%s, %s bytes (%s compressed); %s rows in pca (%sb compressed); redis key: %s (old key %s)',
- params.rule.prefix, params.set.name,
- #data, #ann_data,
- #(params.set.ann.pca or {}), #(pca_data or {}),
- params.set.ann.redis_key, params.ann_key)
+ 'trained ANN %s:%s, %s bytes (%s compressed); %s rows in pca (%sb compressed); redis key: %s (old key %s)',
+ params.rule.prefix, params.set.name,
+ #data, #ann_data,
+ #(params.set.ann.pca or {}), #(pca_data or {}),
+ params.set.ann.redis_key, params.ann_key)
lua_redis.exec_redis_script(redis_script_id.save_unlock,
- { ev_base = params.ev_base, is_write = true },
- redis_save_cb,
- { profile.redis_key,
- redis_ann_prefix(params.rule, params.set.name),
- ann_data,
- profile_serialized,
- tostring(params.rule.ann_expire),
- tostring(os.time()),
- params.ann_key, -- old key to unlock...
- roc_thresholds_serialized,
- pca_data,
- })
+ { ev_base = params.ev_base, is_write = true },
+ redis_save_cb,
+ { profile.redis_key,
+ redis_ann_prefix(params.rule, params.set.name),
+ ann_data,
+ profile_serialized,
+ tostring(params.rule.ann_expire),
+ tostring(os.time()),
+ params.ann_key, -- old key to unlock...
+ roc_thresholds_serialized,
+ pca_data,
+ providers_meta_serialized,
+ ucl.to_format(norm_stats, 'json-compact', true),
+ })
end
end
params.set.learning_spawned = true
register_lock_extender(params.rule, params.set, params.ev_base, params.ann_key)
return
-
end
end
-- Use static user defined profile
-- Ensure that we have an array...
lua_util.debugm(N, rspamd_config, "use static profile for %s (%s): %s",
- rule.prefix, selt.name, profile)
+ rule.prefix, selt.name, profile)
if not profile[1] then
profile = lua_util.keys(profile)
end
selt.symbols = profile
else
lua_util.debugm(N, rspamd_config, "use dynamic cfg based profile for %s (%s)",
- rule.prefix, selt.name)
+ rule.prefix, selt.name)
end
local function filter_symbols_predicate(sname)
selt.prefix = redis_ann_prefix(rule, selt.name)
rspamd_logger.messagex(rspamd_config,
- 'use NN prefix for rule %s; settings id "%s"; symbols digest: "%s"',
- selt.prefix, selt.name, selt.digest)
+ 'use NN prefix for rule %s; settings id "%s"; symbols digest: "%s"',
+ selt.prefix, selt.name, selt.digest)
lua_redis.register_prefix(selt.prefix, N,
- string.format('NN prefix for rule "%s"; settings id "%s"',
- selt.prefix, selt.name), {
- persistent = true,
- type = 'zlist',
- })
+ string.format('NN prefix for rule "%s"; settings id "%s"',
+ selt.prefix, selt.name), {
+ persistent = true,
+ type = 'zlist',
+ })
-- Versions
lua_redis.register_prefix(selt.prefix .. '_\\d+', N,
- string.format('NN storage for rule "%s"; settings id "%s"',
- selt.prefix, selt.name), {
- persistent = true,
- type = 'hash',
- })
+ string.format('NN storage for rule "%s"; settings id "%s"',
+ selt.prefix, selt.name), {
+ persistent = true,
+ type = 'hash',
+ })
lua_redis.register_prefix(selt.prefix .. '_\\d+_spam_set', N,
- string.format('NN learning set (spam) for rule "%s"; settings id "%s"',
- selt.prefix, selt.name), {
- persistent = true,
- type = 'set',
- })
+ string.format('NN learning set (spam) for rule "%s"; settings id "%s"',
+ selt.prefix, selt.name), {
+ persistent = true,
+ type = 'set',
+ })
lua_redis.register_prefix(selt.prefix .. '_\\d+_ham_set', N,
- string.format('NN learning set (ham) for rule "%s"; settings id "%s"',
- rule.prefix, selt.name), {
- persistent = true,
- type = 'set',
- })
+ string.format('NN learning set (ham) for rule "%s"; settings id "%s"',
+ rule.prefix, selt.name), {
+ persistent = true,
+ type = 'set',
+ })
end
for k, rule in pairs(settings.rules) do
if nelt and lua_util.distance_sorted(ex.symbols, nelt.symbols) == 0 then
-- Equal symbols, add reference
lua_util.debugm(N, rspamd_config,
- 'added reference from settings id %s to %s; same symbols',
- nelt.name, ex.name)
+ 'added reference from settings id %s to %s; same symbols',
+ nelt.name, ex.name)
rule.settings[settings_id] = id
nelt = nil
end
if nelt then
rule.settings[settings_id] = nelt
lua_util.debugm(N, rspamd_config, 'added new settings id %s(%s) to %s',
- nelt.name, settings_id, rule.prefix)
+ nelt.name, settings_id, rule.prefix)
end
end
end
return set
end
-local function result_to_vector(task, profile)
+result_to_vector = function(task, profile)
if not profile.zeros then
-- Fill zeros vector
local zeros = {}
return {
can_push_train_vector = can_push_train_vector,
+ collect_features = collect_features,
create_ann = create_ann,
default_options = default_options,
+ build_providers_meta = build_providers_meta,
+ apply_normalization = apply_normalization,
gen_unlock_cb = gen_unlock_cb,
get_rule_settings = get_rule_settings,
load_scripts = load_scripts,
module_config = module_config,
new_ann_key = new_ann_key,
+ providers_config_digest = providers_config_digest,
+ register_provider = register_provider,
plugin_ver = plugin_ver,
process_rules_settings = process_rules_settings,
redis_ann_prefix = redis_ann_prefix,
--- /dev/null
+--[[
+LLM provider for neural feature fusion
+Collects text from the most relevant part and requests embeddings from an LLM API.
+Supports minimal OpenAI- and Ollama-compatible embedding endpoints.
+]] --
+
+local rspamd_http = require "rspamd_http"
+local rspamd_logger = require "rspamd_logger"
+local ucl = require "ucl"
+local lua_mime = require "lua_mime"
+local neural_common = require "plugins/neural"
+local lua_cache = require "lua_cache"
+
+local N = "neural.llm"
+
+local function select_text(task, cfg)
+ local part = lua_mime.get_displayed_text_part(task)
+ if part then
+ local tp = part:get_text()
+ if tp then
+ -- Prefer UTF text content
+ local content = tp:get_content('raw_utf') or tp:get_content('raw')
+ if content and #content > 0 then
+ return content
+ end
+ end
+ -- Fallback to raw content
+ local rc = part:get_raw_content()
+ if type(rc) == 'userdata' then
+ rc = tostring(rc)
+ end
+ return rc
+ end
+
+ -- Fallback to subject if no text part
+ return task:get_subject() or ''
+end
+
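+-- Merge per-provider settings with defaults from the global `gpt` block;
+-- values from the provider entry take precedence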
+local function compose_llm_settings(pcfg)
+ local gpt_settings = rspamd_config:get_all_opt('gpt') or {}
+ local llm_type = pcfg.type or gpt_settings.type or 'openai'
+ local model = pcfg.model or gpt_settings.model
+ local timeout = pcfg.timeout or gpt_settings.timeout or 2.0
+ local url = pcfg.url
+ local api_key = pcfg.api_key or gpt_settings.api_key
+
+ if not url then
+ if llm_type == 'openai' then
+ url = 'https://api.openai.com/v1/embeddings'
+ elseif llm_type == 'ollama' then
+ url = 'http://127.0.0.1:11434/api/embeddings'
+ end
+ end
+
+ return {
+ type = llm_type,
+ model = model,
+ timeout = timeout,
+ url = url,
+ api_key = api_key,
+ cache_ttl = pcfg.cache_ttl or 86400,
+ cache_prefix = pcfg.cache_prefix or 'neural_llm',
+ cache_hash_len = pcfg.cache_hash_len or 16,
+ cache_use_hashing = pcfg.cache_use_hashing ~= false,
+ }
+end
+
+local function extract_embedding(llm_type, parsed)
+ if llm_type == 'openai' then
+ -- { data = [ { embedding = [...] } ] }
+ if parsed and parsed.data and parsed.data[1] and parsed.data[1].embedding then
+ return parsed.data[1].embedding
+ end
+ elseif llm_type == 'ollama' then
+ -- { embedding = [...] }
+ if parsed and parsed.embedding then
+ return parsed.embedding
+ end
+ end
+ return nil
+end
+
+neural_common.register_provider('llm', {
+ collect = function(task, ctx)
+ local pcfg = ctx.config or {}
+ local llm = compose_llm_settings(pcfg)
+
+ if not llm.model then
+ rspamd_logger.debugm(N, task, 'llm provider missing model; skip')
+ return nil
+ end
+
+ local content = select_text(task, pcfg)
+ if not content or #content == 0 then
+ rspamd_logger.debugm(N, task, 'llm provider has no content to embed; skip')
+ return nil
+ end
+
+ local body
+ if llm.type == 'openai' then
+ body = { model = llm.model, input = content }
+ elseif llm.type == 'ollama' then
+ body = { model = llm.model, prompt = content }
+ else
+ rspamd_logger.debugm(N, task, 'unsupported llm type: %s', llm.type)
+ return nil
+ end
+
+ -- Redis cache: use content hash + model + provider as key
+ local cache_ctx = lua_cache.create_cache_context(neural_common.redis_params, {
+ cache_prefix = llm.cache_prefix,
+ cache_ttl = llm.cache_ttl,
+ cache_format = 'messagepack',
+ cache_hash_len = llm.cache_hash_len,
+ cache_use_hashing = llm.cache_use_hashing,
+ }, N)
+
+ -- Use a stable key based on content digest
+ local hasher = require 'rspamd_cryptobox_hash'
+ local key = string.format('%s:%s:%s', llm.type, llm.model or 'model', hasher.create(content):hex())
+
+ local function do_request_and_cache()
+ local headers = { ['Content-Type'] = 'application/json' }
+ if llm.type == 'openai' and llm.api_key then
+ headers['Authorization'] = 'Bearer ' .. llm.api_key
+ end
+
+ local http_params = {
+ url = llm.url,
+ mime_type = 'application/json',
+ timeout = llm.timeout,
+ log_obj = task,
+ headers = headers,
+ body = ucl.to_format(body, 'json-compact', true),
+ task = task,
+ method = 'POST',
+ use_gzip = true,
+ }
+
+ local err, data = rspamd_http.request(http_params)
+ if err then
+ rspamd_logger.debugm(N, task, 'llm request failed: %s', err)
+ return nil
+ end
+
+ local parser = ucl.parser()
+ local ok, perr = parser:parse_string(data.content)
+ if not ok then
+ rspamd_logger.debugm(N, task, 'cannot parse llm response: %s', perr)
+ return nil
+ end
+
+ local parsed = parser:get_object()
+ local embedding = extract_embedding(llm.type, parsed)
+ if not embedding or #embedding == 0 then
+ rspamd_logger.debugm(N, task, 'no embedding in llm response')
+ return nil
+ end
+
+ for i = 1, #embedding do
+ embedding[i] = tonumber(embedding[i]) or 0.0
+ end
+
+ lua_cache.cache_set(task, key, { e = embedding }, cache_ctx)
+ return embedding
+ end
+
+ -- Try cache first
+ local cached_result
+ local done = false
+ lua_cache.cache_get(task, key, cache_ctx, llm.timeout or 2.0,
+ function(_)
+ -- Uncached: perform request synchronously and store
+ cached_result = do_request_and_cache()
+ done = true
+ end,
+ function(_, err, data)
+ if data and data.e then
+ cached_result = data.e
+ end
+ done = true
+ end
+ )
+
+ if not done then
+ -- Fallback: ensure we still do the request now (cache API is async-ready, but we need sync path)
+ cached_result = do_request_and_cache()
+ end
+
+ local embedding = cached_result
+ if not embedding then
+ return nil
+ end
+
+ local meta = {
+ name = pcfg.name or 'llm',
+ type = 'llm',
+ dim = #embedding,
+ weight = pcfg.weight or 1.0,
+ model = llm.model,
+ provider = llm.type,
+ }
+
+ return embedding, meta
+ end
+})
--- /dev/null
+-- Symbols provider: wraps legacy symbols+metatokens vectorization
+
+local neural_common = require "plugins/neural"
+
+neural_common.register_provider('symbols', {
+ collect = function(task, ctx)
+ local vec = neural_common.result_to_vector(task, ctx.profile)
+ return vec, { name = 'symbols', type = 'symbols', dim = #vec, weight = ctx.weight or 1.0 }
+ end
+})
-- key7 - old key
-- key8 - ROC Thresholds
-- key9 - optional PCA
+-- key10 - optional providers_meta (JSON)
+-- key11 - optional norm_stats (JSON)
local now = tonumber(KEYS[6])
redis.call('ZADD', KEYS[2], now, KEYS[4])
redis.call('HSET', KEYS[1], 'ann', KEYS[3])
if KEYS[9] then
redis.call('HSET', KEYS[1], 'pca', KEYS[9])
end
+if KEYS[10] then
+ redis.call('HSET', KEYS[1], 'providers_meta', KEYS[10])
+end
+if KEYS[11] then
+ redis.call('HSET', KEYS[1], 'norm_stats', KEYS[11])
+end
redis.call('HDEL', KEYS[1], 'lock')
redis.call('HDEL', KEYS[7], 'lock')
redis.call('EXPIRE', KEYS[1], tonumber(KEYS[5]))
- -- expire in 10m, to not face race condition with other rspamd replicas refill deleted keys
+-- expire in 10m, to not face race condition with other rspamd replicas refill deleted keys
redis.call('EXPIRE', KEYS[7] .. '_spam_set', 600)
redis.call('EXPIRE', KEYS[7] .. '_ham_set', 600)
return 1
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-]]--
+]] --
if confighelp then
local rspamd_text = require "rspamd_text"
local rspamd_util = require "rspamd_util"
local ts = require("tableshape").types
+-- Load providers
+pcall(require, "plugins/neural/providers/llm")
+pcall(require, "plugins/neural/providers/symbols")
local N = "neural"
version = ts.number,
redis_key = ts.string,
distance = ts.number:is_optional(),
+ providers_digest = ts.string:is_optional(),
}
local has_blas = rspamd_tensor.has_blas()
redis_key = ann_key,
version = version,
digest = set.digest,
- distance = 0 -- Since we are using our own profile
+ distance = 0, -- Since we are using our own profile
+ providers_digest = neural_common.providers_config_digest(rule.providers),
}
local ucl = require "ucl"
local function add_cb(err, _)
if err then
rspamd_logger.errx(task, 'cannot store ANN profile for %s:%s at %s : %s',
- rule.prefix, set.name, profile.redis_key, err)
+ rule.prefix, set.name, profile.redis_key, err)
else
rspamd_logger.infox(task, 'created new ANN profile for %s:%s, data stored at prefix %s',
- rule.prefix, set.name, profile.redis_key)
+ rule.prefix, set.name, profile.redis_key)
end
end
lua_redis.redis_make_request(task,
- rule.redis,
- nil,
- true, -- is write
- add_cb, --callback
- 'ZADD', -- command
- { set.prefix, tostring(rspamd_util.get_time()), profile_serialized }
+ rule.redis,
+ nil,
+ true, -- is write
+ add_cb, --callback
+ 'ZADD', -- command
+ { set.prefix, tostring(rspamd_util.get_time()), profile_serialized }
)
return profile
-- ANN filter function, used to insert scores based on the existing symbols
local function ann_scores_filter(task)
-
for _, rule in pairs(settings.rules) do
local sid = task:get_settings_id() or -1
local ann
profile = set.ann
else
lua_util.debugm(N, task, 'no ann loaded for %s:%s',
- rule.prefix, set.name)
+ rule.prefix, set.name)
end
else
lua_util.debugm(N, task, 'no ann defined in %s for settings id %s',
- rule.prefix, sid)
+ rule.prefix, sid)
end
if ann then
- local vec = neural_common.result_to_vector(task, profile)
+ local vec
+ if rule.providers and #rule.providers > 0 then
+        local fused, meta = neural_common.collect_features(task, rule, profile, 'infer')
+ vec = fused
+ if profile.providers_digest and meta.digest and profile.providers_digest ~= meta.digest then
+ lua_util.debugm(N, task, 'providers digest mismatch for %s:%s, skip ANN apply',
+ rule.prefix, set.name)
+ vec = nil
+ end
+ else
+ vec = neural_common.result_to_vector(task, profile)
+ end
local score
+ if not vec then
+ goto continue_rule
+ end
+ if set.ann.norm_stats then
+ vec = neural_common.apply_normalization(vec, set.ann.norm_stats)
+ end
local out = ann:apply1(vec, set.ann.pca)
score = out[1]
local symscore = string.format('%.3f', score)
task:cache_set(rule.prefix .. '_neural_score', score)
lua_util.debugm(N, task, '%s:%s:%s ann score: %s',
- rule.prefix, set.name, set.ann.version, symscore)
+ rule.prefix, set.name, set.ann.version, symscore)
if score > 0 then
local result = score
end
else
lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (spam threshold)',
- rule.prefix, set.name, set.ann.version, symscore,
- spam_threshold)
+ rule.prefix, set.name, set.ann.version, symscore,
+ spam_threshold)
end
else
local result = -(score)
end
else
lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (ham threshold)',
- rule.prefix, set.name, set.ann.version, result,
- ham_threshold)
+ rule.prefix, set.name, set.ann.version, result,
+ ham_threshold)
end
end
end
+ ::continue_rule::
end
end
if not learn_spam then
skip_reason = string.format('score < spam_score: %f < %f',
- score, train_opts.spam_score)
+ score, train_opts.spam_score)
end
else
learn_spam = verdict == 'spam' or verdict == 'junk'
if not learn_spam then
skip_reason = string.format('verdict: %s',
- verdict)
+ verdict)
end
end
learn_ham = score <= train_opts.ham_score
if not learn_ham then
skip_reason = string.format('score > ham_score: %f > %f',
- score, train_opts.ham_score)
+ score, train_opts.ham_score)
end
else
learn_ham = verdict == 'ham'
if not learn_ham then
skip_reason = string.format('verdict: %s',
- verdict)
+ verdict)
end
end
else
learn_spam = false
-- Explicitly store tokens in cache
- local vec = neural_common.result_to_vector(task, set)
+ local vec
+ if rule.providers and #rule.providers > 0 then
+ local fused = neural_common.collect_features(task, rule, set, 'train')
+ if type(fused) == 'table' then
+ vec = fused
+ end
+ end
+ if not vec then
+ vec = neural_common.result_to_vector(task, set)
+ end
task:cache_set(rule.prefix .. '_neural_vec_mpack', ucl.to_format(vec, 'msgpack'))
task:cache_set(rule.prefix .. '_neural_profile_digest', set.digest)
skip_reason = 'store_pool_only has been set'
local nspam, nham = data[1], data[2]
if neural_common.can_push_train_vector(rule, task, learn_type, nspam, nham) then
- local vec = neural_common.result_to_vector(task, set)
+ local vec
+ if rule.providers and #rule.providers > 0 then
+        local fused = neural_common.collect_features(task, rule, set, 'train')
+ if type(fused) == 'table' then
+ vec = fused
+ end
+ end
+ if not vec then
+ vec = neural_common.result_to_vector(task, set)
+ end
local str = rspamd_util.zstd_compress(table.concat(vec, ';'))
local target_key = set.ann.redis_key .. '_' .. learn_type .. '_set'
local function learn_vec_cb(redis_err)
if redis_err then
rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s',
- rule.prefix, set.name, redis_err)
+ rule.prefix, set.name, redis_err)
else
lua_util.debugm(N, task,
- "add train data for ANN rule " ..
- "%s:%s, save %s vector of %s elts in %s key; %s bytes compressed",
- rule.prefix, set.name, learn_type, #vec, target_key, #str)
+ "add train data for ANN rule " ..
+ "%s:%s, save %s vector of %s elts in %s key; %s bytes compressed",
+ rule.prefix, set.name, learn_type, #vec, target_key, #str)
end
end
lua_redis.redis_make_request(task,
- rule.redis,
- nil,
- true, -- is write
- learn_vec_cb, --callback
- 'SADD', -- command
- { target_key, str } -- arguments
+ rule.redis,
+ nil,
+ true, -- is write
+ learn_vec_cb, --callback
+ 'SADD', -- command
+ { target_key, str } -- arguments
)
else
lua_util.debugm(N, task,
- "do not add %s train data for ANN rule " ..
- "%s:%s",
- learn_type, rule.prefix, set.name)
+ "do not add %s train data for ANN rule " ..
+ "%s:%s",
+ learn_type, rule.prefix, set.name)
end
else
if err then
rspamd_logger.errx(task, 'cannot check if we can train %s:%s : %s',
- rule.prefix, set.name, err)
+ rule.prefix, set.name, err)
elseif type(data) == 'string' then
-- nil return value
rspamd_logger.infox(task, "cannot learn %s ANN %s:%s; redis_key: %s: locked for learning: %s",
- learn_type, rule.prefix, set.name, set.ann.redis_key, data)
+ learn_type, rule.prefix, set.name, set.ann.redis_key, data)
else
rspamd_logger.errx(task, 'cannot check if we can train %s:%s : type of Redis key %s is %s, expected table' ..
- 'please remove this key from Redis manually if you perform upgrade from the previous version',
- rule.prefix, set.name, set.ann.redis_key, type(data))
+        '; please remove this key from Redis manually if you perform upgrade from the previous version',
+ rule.prefix, set.name, set.ann.redis_key, type(data))
end
end
end
-- Need to create or load a profile corresponding to the current configuration
set.ann = new_ann_profile(task, rule, set, 0)
lua_util.debugm(N, task,
- 'requested new profile for %s, set.ann is missing',
- set.name)
+ 'requested new profile for %s, set.ann is missing',
+ set.name)
end
lua_redis.exec_redis_script(neural_common.redis_script_id.vectors_len,
- { task = task, is_write = false },
- vectors_len_cb,
- {
- set.ann.redis_key,
- })
+ { task = task, is_write = false },
+ vectors_len_cb,
+ {
+ set.ann.redis_key,
+ })
else
lua_util.debugm(N, task,
- 'do not push data: train condition not satisfied; reason: not checked existing ANNs')
+ 'do not push data: train condition not satisfied; reason: not checked existing ANNs')
end
else
lua_util.debugm(N, task,
- 'do not push data to key %s: train condition not satisfied; reason: %s',
- (set.ann or {}).redis_key,
- skip_reason)
+ 'do not push data to key %s: train condition not satisfied; reason: %s',
+ (set.ann or {}).redis_key,
+ skip_reason)
end
end
local function redis_ham_cb(err, data)
if err or type(data) ~= 'table' then
rspamd_logger.errx(rspamd_config, 'cannot get ham tokens for ANN %s from redis: %s',
- ann_key, err)
+ ann_key, err)
-- Unlock on error
lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- true, -- is write
- neural_common.gen_unlock_cb(rule, set, ann_key), --callback
- 'HDEL', -- command
- { ann_key, 'lock' }
+ rspamd_config,
+ rule.redis,
+ nil,
+ true, -- is write
+ neural_common.gen_unlock_cb(rule, set, ann_key), --callback
+ 'HDEL', -- command
+ { ann_key, 'lock' }
)
else
-- Decompress and convert to numbers each training vector
ham_elts = process_training_vectors(data)
- neural_common.spawn_train({ worker = worker, ev_base = ev_base,
- rule = rule, set = set, ann_key = ann_key, ham_vec = ham_elts,
- spam_vec = spam_elts })
+ neural_common.spawn_train({
+ worker = worker,
+ ev_base = ev_base,
+ rule = rule,
+ set = set,
+ ann_key = ann_key,
+ ham_vec = ham_elts,
+ spam_vec = spam_elts
+ })
end
end
local function redis_spam_cb(err, data)
if err or type(data) ~= 'table' then
rspamd_logger.errx(rspamd_config, 'cannot get spam tokens for ANN %s from redis: %s',
- ann_key, err)
+ ann_key, err)
-- Unlock ANN on error
lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- true, -- is write
- neural_common.gen_unlock_cb(rule, set, ann_key), --callback
- 'HDEL', -- command
- { ann_key, 'lock' }
+ rspamd_config,
+ rule.redis,
+ nil,
+ true, -- is write
+ neural_common.gen_unlock_cb(rule, set, ann_key), --callback
+ 'HDEL', -- command
+ { ann_key, 'lock' }
)
else
-- Decompress and convert to numbers each training vector
spam_elts = process_training_vectors(data)
-- Now get ham vectors...
lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- false, -- is write
- redis_ham_cb, --callback
- 'SMEMBERS', -- command
- { ann_key .. '_ham_set' }
+ rspamd_config,
+ rule.redis,
+ nil,
+ false, -- is write
+ redis_ham_cb, --callback
+ 'SMEMBERS', -- command
+ { ann_key .. '_ham_set' }
)
end
end
local function redis_lock_cb(err, data)
if err then
rspamd_logger.errx(rspamd_config, 'cannot call lock script for ANN %s from redis: %s',
- ann_key, err)
+ ann_key, err)
elseif type(data) == 'number' and data == 1 then
-- ANN is locked, so we can extract SPAM and HAM vectors and spawn learning
lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- false, -- is write
- redis_spam_cb, --callback
- 'SMEMBERS', -- command
- { ann_key .. '_spam_set' }
+ rspamd_config,
+ rule.redis,
+ nil,
+ false, -- is write
+ redis_spam_cb, --callback
+ 'SMEMBERS', -- command
+ { ann_key .. '_spam_set' }
)
rspamd_logger.infox(rspamd_config, 'lock ANN %s:%s (key name %s) for learning',
- rule.prefix, set.name, ann_key)
+ rule.prefix, set.name, ann_key)
else
local lock_tm = tonumber(data[1])
rspamd_logger.infox(rspamd_config, 'do not learn ANN %s:%s (key name %s), ' ..
- 'locked by another host %s at %s', rule.prefix, set.name, ann_key,
- data[2], os.date('%c', lock_tm))
+ 'locked by another host %s at %s', rule.prefix, set.name, ann_key,
+ data[2], os.date('%c', lock_tm))
end
end
-- Check if we are already learning this network
if set.learning_spawned then
rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, already learning another ANN',
- ann_key)
+ ann_key)
return
end
-- This script returns either a boolean or a pair {'lock_time', 'hostname'} when
-- ANN is locked by another host (or a process, meh)
lua_redis.exec_redis_script(neural_common.redis_script_id.maybe_lock,
- { ev_base = ev_base, is_write = true },
- redis_lock_cb,
- {
- ann_key,
- tostring(os.time()),
- tostring(math.max(10.0, rule.watch_interval * 2)),
- rspamd_util.get_hostname()
- })
+ { ev_base = ev_base, is_write = true },
+ redis_lock_cb,
+ {
+ ann_key,
+ tostring(os.time()),
+ tostring(math.max(10.0, rule.watch_interval * 2)),
+ rspamd_util.get_hostname()
+ })
end
-- This function loads new ann from Redis
local function data_cb(err, data)
if err then
rspamd_logger.errx(rspamd_config, 'cannot get ANN data from key: %s; %s',
- ann_key, err)
+ ann_key, err)
else
if type(data) == 'table' then
if type(data[1]) == 'userdata' and data[1].cookie == text_cookie then
if _err or not ann_data then
rspamd_logger.errx(rspamd_config, 'cannot decompress ANN for %s from Redis key %s: %s',
- rule.prefix .. ':' .. set.name, ann_key, _err)
+ rule.prefix .. ':' .. set.name, ann_key, _err)
return
else
ann = rspamd_kann.load(ann_data)
version = profile.version,
symbols = profile.symbols,
distance = min_diff,
- redis_key = profile.redis_key
+ redis_key = profile.redis_key,
+ providers_digest = profile.providers_digest,
}
local ucl = require "ucl"
end
-- Also update rank for the loaded ANN to avoid removal
lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- true, -- is write
- rank_cb, --callback
- 'ZADD', -- command
- { set.prefix, tostring(rspamd_util.get_time()), profile_serialized }
+ rspamd_config,
+ rule.redis,
+ nil,
+ true, -- is write
+ rank_cb, --callback
+ 'ZADD', -- command
+ { set.prefix, tostring(rspamd_util.get_time()), profile_serialized }
)
rspamd_logger.infox(rspamd_config,
- 'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s',
- rule.prefix, set.name, ann_key, #data[1], profile.version)
+ 'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s',
+ rule.prefix, set.name, ann_key, #data[1], profile.version)
else
rspamd_logger.errx(rspamd_config,
- 'cannot unpack/deserialise ANN for %s:%s from Redis key %s',
- rule.prefix, set.name, ann_key)
+ 'cannot unpack/deserialise ANN for %s:%s from Redis key %s',
+ rule.prefix, set.name, ann_key)
end
end
else
lua_util.debugm(N, rspamd_config, 'missing ANN for %s:%s in Redis key %s',
- rule.prefix, set.name, ann_key)
+ rule.prefix, set.name, ann_key)
end
if set.ann and set.ann.ann and type(data[2]) == 'userdata' and data[2].cookie == text_cookie then
local roc_thresholds = parser:get_object()
set.ann.roc_thresholds = roc_thresholds
rspamd_logger.infox(rspamd_config,
- 'loaded ROC thresholds for %s:%s; version=%s',
- rule.prefix, set.name, profile.version)
+ 'loaded ROC thresholds for %s:%s; version=%s',
+ rule.prefix, set.name, profile.version)
rspamd_logger.debugx("ROC thresholds: %s", roc_thresholds)
end
end
-- We can use PCA
set.ann.pca = rspamd_tensor.load(pca_data)
rspamd_logger.infox(rspamd_config,
- 'loaded PCA for ANN for %s:%s from %s; %s bytes compressed; version=%s',
- rule.prefix, set.name, ann_key, #data[3], profile.version)
+ 'loaded PCA for ANN for %s:%s from %s; %s bytes compressed; version=%s',
+ rule.prefix, set.name, ann_key, #data[3], profile.version)
else
-- no need in pca, why is it there?
rspamd_logger.warnx(rspamd_config,
- 'extra PCA for ANN for %s:%s from Redis key %s: no max inputs defined',
- rule.prefix, set.name, ann_key)
+ 'extra PCA for ANN for %s:%s from Redis key %s: no max inputs defined',
+ rule.prefix, set.name, ann_key)
end
else
-- pca can be missing merely if we have no max_inputs
if rule.max_inputs then
rspamd_logger.errx(rspamd_config, 'cannot unpack/deserialise ANN for %s:%s from Redis key %s: no PCA: %s',
- rule.prefix, set.name, ann_key, _err)
+ rule.prefix, set.name, ann_key, _err)
set.ann.ann = nil
else
-- It is okay
end
end
+ -- Providers meta (optional)
+ if set.ann and set.ann.ann and type(data[4]) == 'userdata' and data[4].cookie == text_cookie then
+ local ucl = require "ucl"
+ local parser = ucl.parser()
+ local ok = parser:parse_text(data[4])
+ if ok then
+ set.ann.providers_meta = parser:get_object()
+ end
+ end
+ -- Normalization stats (optional)
+ if set.ann and set.ann.ann and type(data[5]) == 'userdata' and data[5].cookie == text_cookie then
+ local ucl = require "ucl"
+ local parser = ucl.parser()
+ local ok = parser:parse_text(data[5])
+ if ok then
+ set.ann.norm_stats = parser:get_object()
+ end
+ end
else
lua_util.debugm(N, rspamd_config, 'no ANN key for %s:%s in Redis key %s',
- rule.prefix, set.name, ann_key)
+ rule.prefix, set.name, ann_key)
end
end
end
lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- false, -- is write
- data_cb, --callback
- 'HMGET', -- command
- { ann_key, 'ann', 'roc_thresholds', 'pca' }, -- arguments
- { opaque_data = true }
+ rspamd_config,
+ rule.redis,
+ nil,
+ false, -- is write
+ data_cb, --callback
+ 'HMGET', -- command
+ { ann_key, 'ann', 'roc_thresholds', 'pca', 'providers_meta', 'norm_stats' }, -- arguments
+ { opaque_data = true }
)
end
if set.ann.version < sel_elt.version then
-- Load new ann
rspamd_logger.infox(rspamd_config, 'ann %s is changed, ' ..
- 'our version = %s, remote version = %s',
- rule.prefix .. ':' .. set.name,
- set.ann.version,
- sel_elt.version)
+ 'our version = %s, remote version = %s',
+ rule.prefix .. ':' .. set.name,
+ set.ann.version,
+ sel_elt.version)
load_new_ann(rule, ev_base, set, sel_elt, min_diff)
else
lua_util.debugm(N, rspamd_config, 'ann %s is not changed, ' ..
- 'our version = %s, remote version = %s',
- rule.prefix .. ':' .. set.name,
- set.ann.version,
- sel_elt.version)
+ 'our version = %s, remote version = %s',
+ rule.prefix .. ':' .. set.name,
+ set.ann.version,
+ sel_elt.version)
end
else
-- We have some different ANN, so we need to compare distance
if set.ann.distance > min_diff then
-- Load more specific ANN
rspamd_logger.infox(rspamd_config, 'more specific ann is available for %s, ' ..
- 'our distance = %s, remote distance = %s',
- rule.prefix .. ':' .. set.name,
- set.ann.distance,
- min_diff)
+ 'our distance = %s, remote distance = %s',
+ rule.prefix .. ':' .. set.name,
+ set.ann.distance,
+ min_diff)
load_new_ann(rule, ev_base, set, sel_elt, min_diff)
else
lua_util.debugm(N, rspamd_config, 'ann %s is not changed or less specific, ' ..
- 'our distance = %s, remote distance = %s',
- rule.prefix .. ':' .. set.name,
- set.ann.distance,
- min_diff)
+ 'our distance = %s, remote distance = %s',
+ rule.prefix .. ':' .. set.name,
+ set.ann.distance,
+ min_diff)
end
end
else
local ann_key = sel_elt.redis_key
lua_util.debugm(N, rspamd_config, "check if ANN %s needs to be trained",
- ann_key)
+ ann_key)
-- Create continuation closure
local redis_len_cb_gen = function(cont_cb, what, is_final)
return function(err, data)
if err then
rspamd_logger.errx(rspamd_config,
- 'cannot get ANN %s trains %s from redis: %s', what, ann_key, err)
+ 'cannot get ANN %s trains %s from redis: %s', what, ann_key, err)
elseif data and type(data) == 'number' or type(data) == 'string' then
local ntrains = tonumber(data) or 0
lens[what] = ntrains
end
if max_len >= rule.train.max_trains and fun.all(len_bias_check_pred, lens) then
lua_util.debugm(N, rspamd_config,
- 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
- ann_key, lens, rule.train.max_trains, what)
+ 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
+ ann_key, lens, rule.train.max_trains, what)
cont_cb()
else
lua_util.debugm(N, rspamd_config,
- 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
- ann_key, what, lens, rule.train.max_trains)
+ 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
+ ann_key, what, lens, rule.train.max_trains)
end
else
-- Probabilistic mode, just ensure that at least one vector is okay
if min_len > 0 and max_len >= rule.train.max_trains then
lua_util.debugm(N, rspamd_config,
- 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
- ann_key, lens, rule.train.max_trains, what)
+ 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
+ ann_key, lens, rule.train.max_trains, what)
cont_cb()
else
lua_util.debugm(N, rspamd_config,
- 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
- ann_key, what, lens, rule.train.max_trains)
+ 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
+ ann_key, what, lens, rule.train.max_trains)
end
end
-
else
lua_util.debugm(N, rspamd_config,
- 'checked %s vectors in ANN %s: %s vectors; %s required, need to check other class vectors',
- what, ann_key, ntrains, rule.train.max_trains)
+ 'checked %s vectors in ANN %s: %s vectors; %s required, need to check other class vectors',
+ what, ann_key, ntrains, rule.train.max_trains)
cont_cb()
end
end
end
-
end
local function initiate_train()
rspamd_logger.infox(rspamd_config,
- 'need to learn ANN %s after %s required learn vectors',
- ann_key, lens)
+ 'need to learn ANN %s after %s required learn vectors',
+ ann_key, lens)
do_train_ann(worker, ev_base, rule, set, ann_key)
end
-- Spam vector is OK, check ham vector length
local function check_ham_len()
lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- false, -- is write
- redis_len_cb_gen(initiate_train, 'ham', true), --callback
- 'SCARD', -- command
- { ann_key .. '_ham_set' }
+ rspamd_config,
+ rule.redis,
+ nil,
+ false, -- is write
+ redis_len_cb_gen(initiate_train, 'ham', true), --callback
+ 'SCARD', -- command
+ { ann_key .. '_ham_set' }
)
end
lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- false, -- is write
- redis_len_cb_gen(check_ham_len, 'spam', false), --callback
- 'SCARD', -- command
- { ann_key .. '_spam_set' }
+ rspamd_config,
+ rule.redis,
+ nil,
+ false, -- is write
+ redis_len_cb_gen(check_ham_len, 'spam', false), --callback
+ 'SCARD', -- command
+ { ann_key .. '_spam_set' }
)
end
end
local res, ucl_err = parser:parse_string(element)
if not res then
rspamd_logger.warnx(rspamd_config, 'cannot parse ANN from redis: %s',
- ucl_err)
+ ucl_err)
return nil
else
local profile = parser:get_object()
local function members_cb(err, data)
if err then
rspamd_logger.errx(cfg, 'cannot get ANNs list from redis: %s',
- err)
+ err)
set.can_store_vectors = true
elseif type(data) == 'table' then
lua_util.debugm(N, cfg, '%s: process element %s:%s',
- what, rule.prefix, set.name)
+ what, rule.prefix, set.name)
process_callback(worker, ev_base, rule, set, fun.map(load_ann_profile, data))
set.can_store_vectors = true
end
-- Select the most appropriate to our profile but it should not differ by more
-- than 30% of symbols
lua_redis.redis_make_request_taskless(ev_base,
- cfg,
- rule.redis,
- nil,
- false, -- is write
- members_cb, --callback
- 'ZREVRANGE', -- command
- { set.prefix, '0', tostring(settings.max_profiles) } -- arguments
+ cfg,
+ rule.redis,
+ nil,
+ false, -- is write
+ members_cb, --callback
+ 'ZREVRANGE', -- command
+ { set.prefix, '0', tostring(settings.max_profiles) } -- arguments
)
end
end -- Cycle over all settings
local function invalidate_cb(err, data)
if err then
rspamd_logger.errx(cfg, 'cannot exec invalidate script in redis: %s',
- err)
+ err)
elseif type(data) == 'table' then
for _, expired in ipairs(data) do
local profile = load_ann_profile(expired)
rspamd_logger.infox(cfg, 'invalidated ANN for %s; redis key: %s; version=%s',
- rule.prefix .. ':' .. set.name,
- profile.redis_key,
- profile.version)
+ rule.prefix .. ':' .. set.name,
+ profile.redis_key,
+ profile.version)
end
end
end
if type(set) == 'table' then
lua_redis.exec_redis_script(neural_common.redis_script_id.maybe_invalidate,
- { ev_base = ev_base, is_write = true },
- invalidate_cb,
- { set.prefix, tostring(settings.max_profiles) })
+ { ev_base = ev_base, is_write = true },
+ invalidate_cb,
+ { set.prefix, tostring(settings.max_profiles) })
end
end
end
if verdict == 'passthrough' then
lua_util.debugm(N, task, 'ignore task as its verdict is %s(%s)',
- verdict, score)
+ verdict, score)
return
end
if score ~= score then
lua_util.debugm(N, task, 'ignore task as its score is nan (%s verdict)',
- verdict)
+ verdict)
return
end
else
lua_util.debugm(N, task, 'settings not found in rule %s', rule.prefix)
end
-
end
end
if rule_elt.max_inputs and not has_blas then
rspamd_logger.errx('cannot set max inputs to %s as BLAS is not compiled in',
- rule_elt.name, rule_elt.max_inputs)
+ rule_elt.name, rule_elt.max_inputs)
rule_elt.max_inputs = nil
end
+ -- Phase 4: basic provider config validation
+ if rule_elt.providers and #rule_elt.providers > 0 then
+ for i, pcfg in ipairs(rule_elt.providers) do
+ if not (pcfg.type or pcfg.name) then
+ rspamd_logger.errx(rspamd_config, 'provider at index %s in rule %s has no type/name; will be ignored', i, k)
+ end
+ if (pcfg.type == 'llm' or pcfg.name == 'llm') and not (pcfg.model or (rspamd_config:get_all_opt('gpt') or {}).model) then
+ rspamd_logger.errx(rspamd_config,
+ 'llm provider in rule %s requires model; please set providers[i].model or gpt.model', k)
+ end
+ end
+ end
+
rspamd_logger.infox(rspamd_config, "register ann rule %s", k)
settings.rules[k] = rule_elt
rspamd_config:set_metric_symbol({
rspamd_config:add_on_load(function(cfg, ev_base, worker)
if worker:is_scanner() then
rspamd_config:add_periodic(ev_base, 0.0,
- function(_, _)
- return check_anns(worker, cfg, ev_base, rule, process_existing_ann,
- 'try_load_ann')
- end)
+ function(_, _)
+ return check_anns(worker, cfg, ev_base, rule, process_existing_ann,
+ 'try_load_ann')
+ end)
end
if worker:is_primary_controller() then
-- We also want to train neural nets when they have enough data
rspamd_config:add_periodic(ev_base, 0.0,
- function(_, _)
- -- Clean old ANNs
- cleanup_anns(rule, cfg, ev_base)
- return check_anns(worker, cfg, ev_base, rule, maybe_train_existing_ann,
- 'try_train_ann')
- end)
+ function(_, _)
+ -- Clean old ANNs
+ cleanup_anns(rule, cfg, ev_base)
+ return check_anns(worker, cfg, ev_base, rule, maybe_train_existing_ann,
+ 'try_train_ann')
+ end)
end
end)
end