]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
Neural module rework: provider-based feature fusion, LLM embeddings, normalization...
authorVsevolod Stakhov <vsevolod@rspamd.com>
Mon, 18 Aug 2025 12:42:53 +0000 (13:42 +0100)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Mon, 18 Aug 2025 12:42:53 +0000 (13:42 +0100)
This PR evolves the neural module from a symbols-only scorer into a general feature-fusion classifier with pluggable providers. It adds an LLM embedding provider, introduces trained normalization and metadata persistence, and isolates new models via a schema/prefix bump.

- The existing neural module is limited to metatokens and symbols.
- We want to combine multiple feature sources (LLM embeddings now; Bayes/FastText later).
- Ensure consistent train/infer behavior with stored normalization and provider metadata.
- Improve operability with caching, digest checks, and safer rollouts.

- Provider architecture
  - Provider registry and fusion: `collect_features(task, rule)` concatenates provider vectors with optional weights.
  - New LLM provider: `lualib/plugins/neural/providers/llm.lua` using `rspamd_http` and `lua_cache` for Redis-backed embedding caching.
  - Symbols provider extracted: `lualib/plugins/neural/providers/symbols.lua`.
- Normalization and PCA
  - Configurable fusion normalization: none/unit/zscore.
  - Trained normalization stats computed during training and applied at inference.
  - Existing global PCA preserved; loaded/saved alongside ANN.
- Schema and compatibility
  - `plugin_ver` bumped to '3' to isolate from earlier profiles.
  - Redis save/load extended:
    - Profiles include `providers_digest`.
    - ANN hash can include `providers_meta`, `norm_stats`, `pca`, `roc_thresholds`, `ann`.
  - ANN load validates provider digest and skips apply on mismatch.
- Performance and reliability
  - LLM embeddings cached in Redis (content+model keyed).
  - Graceful fallback to symbols if providers not configured or fail.
  - Basic provider configuration validation.

- `lualib/plugins/neural.lua`: provider registry, fusion, normalization helpers, profile digests, training pipeline updates.
- `src/plugins/lua/neural.lua`: integrates fusion into inference/learning, loads new metadata, applies normalization, validates digest.
- `lualib/plugins/neural/providers/llm.lua`: LLM embeddings with Redis cache.
- `lualib/plugins/neural/providers/symbols.lua`: legacy symbols provider wrapper.
- `lualib/redis_scripts/neural_save_unlock.lua`: stores `providers_meta` and `norm_stats` in ANN hash.
- `NEURAL_REWORK_PLAN.md`: design and phased TODO.

- Enable LLM alongside symbols:
```ucl
neural {
  rules {
    default {
      providers = [
        { type = "symbols"; weight = 0.5; },
        { type = "llm"; model = "text-embed-1"; url = "https://api.openai.com/v1/embeddings";
          cache_ttl = 86400; weight = 1.0; }
      ];
      fusion { normalization = "zscore"; }
      roc_enabled = true;
      max_inputs = 256; # optional PCA
    }
  }
}
```
- LLM provider uses `gpt` block for defaults if present (e.g., API key). You can override `model`, `url`, `timeout`, and cache parameters per provider entry.

- Existing (v2) neural profiles remain unaffected (new `plugin_ver = '3'` prefixes).
- New profiles embed `providers_digest`; incompatible provider sets won’t be applied.
- No immediate cleanup required; TTL-based cleanup keeps old keys around until expiry.

- Validated: provider digest checks, ANN load/save roundtrip, normalization application at inference, LLM caching paths, symbols fallback.
- Please test with/without LLM provider and with `fusion.normalization = none|unit|zscore`.

- LLM latency/cost is mitigated by Redis caching; timeouts are configurable per provider.
- Privacy: use trusted endpoints; no content leaves unless configured.
- Failure behavior: missing/failed providers degrade to others; training/inference can proceed with partial features.

- Rules without `providers` continue to use symbols-only behavior.
- Existing command surface unchanged; future PR will introduce `rspamc learn_neural:*` and controller endpoints.

- [x] Provider registry and fusion
- [x] LLM provider with Redis caching
- [x] Symbols provider split
- [x] Normalization (unit/zscore) with trained stats
- [x] Redis schema v3 additions and profile digest
- [x] Inference uses trained normalization
- [x] Basic provider validation and fallbacks
- [x] Plan document
- [ ] Per-provider budgets/metrics and circuit breaker for LLM
- [ ] Expand providers: Bayes and FastText/subword vectors
- [ ] Per-provider PCA and learned fusion
- [ ] New CLI (`rspamc learn_neural`) and status/invalidate endpoints
- [ ] Documentation expansion under `docs/modules/neural.md`

lualib/plugins/neural.lua
lualib/plugins/neural/providers/llm.lua [new file with mode: 0644]
lualib/plugins/neural/providers/symbols.lua [new file with mode: 0644]
lualib/redis_scripts/neural_save_unlock.lua
src/plugins/lua/neural.lua

index 54521466976f5074063b7c50e9818720b1871edf..b13c6a8273c71dff35521ad5d340594ab7c6db57 100644 (file)
@@ -12,7 +12,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
-]]--
+]] --
 
 local fun = require "fun"
 local lua_redis = require "lua_redis"
@@ -28,7 +28,7 @@ local ucl = require "ucl"
 local N = 'neural'
 
 -- Used in prefix to avoid wrong ANN to be loaded
-local plugin_ver = '2'
+local plugin_ver = '3'
 
 -- Module vars
 local default_options = {
@@ -43,26 +43,33 @@ local default_options = {
     learn_threads = 1,
     learn_mode = 'balanced', -- Possible values: balanced, proportional
     learning_rate = 0.01,
-    classes_bias = 0.0, -- balanced mode: what difference is allowed between classes (1:1 proportion means 0 bias)
-    spam_skip_prob = 0.0, -- proportional mode: spam skip probability (0-1)
-    ham_skip_prob = 0.0, -- proportional mode: ham skip probability
+    classes_bias = 0.0,      -- balanced mode: what difference is allowed between classes (1:1 proportion means 0 bias)
+    spam_skip_prob = 0.0,    -- proportional mode: spam skip probability (0-1)
+    ham_skip_prob = 0.0,     -- proportional mode: ham skip probability
     store_pool_only = false, -- store tokens in cache only (disables autotrain);
     -- neural_vec_mpack stores vector of training data in messagepack neural_profile_digest stores profile digest
   },
   watch_interval = 60.0,
   lock_expire = 600,
   learning_spawned = false,
-  ann_expire = 60 * 60 * 24 * 2, -- 2 days
-  hidden_layer_mult = 1.5, -- number of neurons in the hidden layer
-  roc_enabled = false, -- Use ROC to find the best possible thresholds for ham and spam. If spam_score_threshold or ham_score_threshold is defined, it takes precedence over ROC thresholds.
+  ann_expire = 60 * 60 * 24 * 2,    -- 2 days
+  hidden_layer_mult = 1.5,          -- number of neurons in the hidden layer
+  roc_enabled = false,              -- Use ROC to find the best possible thresholds for ham and spam. If spam_score_threshold or ham_score_threshold is defined, it takes precedence over ROC thresholds.
   roc_misclassification_cost = 0.5, -- Cost of misclassifying a spam message (must be 0..1).
-  spam_score_threshold = nil, -- neural score threshold for spam (must be 0..1 or nil to disable)
-  ham_score_threshold = nil, -- neural score threshold for ham (must be 0..1 or nil to disable)
-  flat_threshold_curve = false, -- use binary classification 0/1 when threshold is reached
+  spam_score_threshold = nil,       -- neural score threshold for spam (must be 0..1 or nil to disable)
+  ham_score_threshold = nil,        -- neural score threshold for ham (must be 0..1 or nil to disable)
+  flat_threshold_curve = false,     -- use binary classification 0/1 when threshold is reached
   symbol_spam = 'NEURAL_SPAM',
   symbol_ham = 'NEURAL_HAM',
-  max_inputs = nil, -- when PCA is used
-  blacklisted_symbols = {}, -- list of symbols skipped in neural processing
+  max_inputs = nil,              -- when PCA is used
+  blacklisted_symbols = {},      -- list of symbols skipped in neural processing
+  -- Phase 0 additions (scaffolding for feature providers)
+  providers = nil,               -- list of provider configs; if nil, fallback to symbols-only provider
+  fusion = {
+    normalization = 'none',      -- none|unit|zscore (zscore requires stats)
+    per_provider_pca = false,    -- if true, apply PCA per provider before fusion (not active yet)
+  },
+  disable_symbols_input = false, -- when true, do not use symbols provider unless explicitly listed
 }
 
 -- Rule structure:
@@ -87,7 +94,7 @@ local default_options = {
 
 local settings = {
   rules = {},
-  prefix = 'rn', -- Neural network default prefix
+  prefix = 'rn',    -- Neural network default prefix
   max_profiles = 3, -- Maximum number of NN profiles stored
 }
 
@@ -103,15 +110,41 @@ local redis_lua_script_save_unlock = "neural_save_unlock.lua"
 
 local redis_script_id = {}
 
+-- Provider registry (Phase 0 scaffolding)
+local registered_providers = {}
+
+--- Registers a feature provider implementation
+-- @param name string
+-- @param provider table with function collect(task, ctx) -> vector(table of numbers), meta(table)
+local function register_provider(name, provider)
+  registered_providers[name] = provider
+end
+
+local function get_provider(name)
+  return registered_providers[name]
+end
+
+-- Forward declaration
+local result_to_vector
+
+-- Built-in symbols provider (compatibility path)
+register_provider('symbols', {
+  collect = function(task, ctx)
+    -- ctx.profile is expected for symbols provider
+    local vec = result_to_vector(task, ctx.profile)
+    return vec, { name = 'symbols', type = 'symbols', dim = #vec, weight = ctx.weight or 1.0 }
+  end
+})
+
 local function load_scripts()
   redis_script_id.vectors_len = lua_redis.load_redis_script_from_file(redis_lua_script_vectors_len,
-      redis_params)
+    redis_params)
   redis_script_id.maybe_invalidate = lua_redis.load_redis_script_from_file(redis_lua_script_maybe_invalidate,
-      redis_params)
+    redis_params)
   redis_script_id.maybe_lock = lua_redis.load_redis_script_from_file(redis_lua_script_maybe_lock,
-      redis_params)
+    redis_params)
   redis_script_id.save_unlock = lua_redis.load_redis_script_from_file(redis_lua_script_save_unlock,
-      redis_params)
+    redis_params)
 end
 
 local function create_ann(n, nlayers, rule)
@@ -154,16 +187,99 @@ local function learn_pca(inputs, max_inputs)
   return w
 end
 
+-- Build providers metadata for storage alongside ANN
+local function build_providers_meta(metas)
+  if not metas or #metas == 0 then return nil end
+  local out = {}
+  for i, m in ipairs(metas) do
+    out[i] = {
+      name = m.name,
+      type = m.type,
+      dim = m.dim,
+      weight = m.weight,
+      model = m.model,
+      provider = m.provider,
+    }
+  end
+  return out
+end
+
+-- Normalization helpers
+local function l2_normalize_vector(vec)
+  local sumsq = 0.0
+  for i = 1, #vec do
+    local v = vec[i]
+    sumsq = sumsq + v * v
+  end
+  if sumsq > 0 then
+    local inv = 1.0 / math.sqrt(sumsq)
+    for i = 1, #vec do
+      vec[i] = vec[i] * inv
+    end
+  end
+  return vec
+end
+
+local function compute_zscore_stats(inputs)
+  local n = #inputs
+  if n == 0 then return nil end
+  local d = #inputs[1]
+  local mean = {}
+  local m2 = {}
+  for j = 1, d do
+    mean[j] = 0.0
+    m2[j] = 0.0
+  end
+  for i = 1, n do
+    local x = inputs[i]
+    for j = 1, d do
+      local delta = x[j] - mean[j]
+      mean[j] = mean[j] + delta / i
+      m2[j] = m2[j] + delta * (x[j] - mean[j])
+    end
+  end
+  local std = {}
+  for j = 1, d do
+    std[j] = math.sqrt((n > 1 and (m2[j] / (n - 1))) or 0.0)
+    if std[j] == 0 or std[j] ~= std[j] then
+      std[j] = 1.0 -- avoid division by zero and NaN
+    end
+  end
+  return { mode = 'zscore', mean = mean, std = std }
+end
+
+local function apply_normalization(vec, norm_stats_or_mode)
+  if not norm_stats_or_mode then return vec end
+  if type(norm_stats_or_mode) == 'string' then
+    if norm_stats_or_mode == 'unit' then
+      return l2_normalize_vector(vec)
+    else
+      return vec
+    end
+  else
+    if norm_stats_or_mode.mode == 'unit' then
+      return l2_normalize_vector(vec)
+    elseif norm_stats_or_mode.mode == 'zscore' and norm_stats_or_mode.mean and norm_stats_or_mode.std then
+      local mean = norm_stats_or_mode.mean
+      local std = norm_stats_or_mode.std
+      for i = 1, math.min(#vec, #mean) do
+        vec[i] = (vec[i] - (mean[i] or 0.0)) / (std[i] or 1.0)
+      end
+      return vec
+    else
+      return vec
+    end
+  end
+end
+
 -- This function computes optimal threshold using ROC for the given set of inputs.
 -- Returns a threshold that minimizes:
 --        alpha * (false_positive_rate)  +  beta * (false_negative_rate)
 --        Where alpha is cost of false positive result
 --              beta is cost of false negative result
 local function get_roc_thresholds(ann, inputs, outputs, alpha, beta)
-
   -- Sorts list x and list y based on the values in list x.
   local sort_relative = function(x, y)
-
     local r = {}
 
     assert(#x == #y)
@@ -219,7 +335,6 @@ local function get_roc_thresholds(ann, inputs, outputs, alpha, beta)
   spam_count_ahead[n_samples + 1] = 0
 
   for i = n_samples, 1, -1 do
-
     if outputs[i][1] == 0 then
       n_ham = n_ham + 1
       ham_count_ahead[i] = 1
@@ -283,34 +398,34 @@ end
 -- `set.learning_spawned` is set to `true`
 local function register_lock_extender(rule, set, ev_base, ann_key)
   rspamd_config:add_periodic(ev_base, 30.0,
-      function()
-        local function redis_lock_extend_cb(err, _)
-          if err then
-            rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s',
-                ann_key, err)
-          else
-            rspamd_logger.infox(rspamd_config, 'extend lock for ANN %s for 30 seconds',
-                ann_key)
-          end
-        end
-
-        if set.learning_spawned then
-          lua_redis.redis_make_request_taskless(ev_base,
-              rspamd_config,
-              rule.redis,
-              nil,
-              true, -- is write
-              redis_lock_extend_cb, --callback
-              'HINCRBY', -- command
-              { ann_key, 'lock', '30' }
-          )
+    function()
+      local function redis_lock_extend_cb(err, _)
+        if err then
+          rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s',
+            ann_key, err)
         else
-          lua_util.debugm(N, rspamd_config, "stop lock extension as learning_spawned is false")
-          return false -- do not plan any more updates
+          rspamd_logger.infox(rspamd_config, 'extend lock for ANN %s for 30 seconds',
+            ann_key)
         end
+      end
 
-        return true
+      if set.learning_spawned then
+        lua_redis.redis_make_request_taskless(ev_base,
+          rspamd_config,
+          rule.redis,
+          nil,
+          true,                 -- is write
+          redis_lock_extend_cb, --callback
+          'HINCRBY',            -- command
+          { ann_key, 'lock', '30' }
+        )
+      else
+        lua_util.debugm(N, rspamd_config, "stop lock extension as learning_spawned is false")
+        return false -- do not plan any more updates
       end
+
+      return true
+    end
   )
 end
 
@@ -332,10 +447,10 @@ local function can_push_train_vector(rule, task, learn_type, nspam, nham)
           local skip_rate = 1.0 - nham / (nspam + 1)
           if coin < skip_rate - train_opts.classes_bias then
             rspamd_logger.infox(task,
-                'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored',
-                learn_type,
-                skip_rate - train_opts.classes_bias,
-                nspam, nham)
+              'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored',
+              learn_type,
+              skip_rate - train_opts.classes_bias,
+              nspam, nham)
             return false
           end
         end
@@ -343,8 +458,8 @@ local function can_push_train_vector(rule, task, learn_type, nspam, nham)
       else
         -- Enough learns
         rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many spam samples: %s',
-            learn_type,
-            nspam)
+          learn_type,
+          nspam)
       end
     else
       if nham <= train_opts.max_trains then
@@ -353,17 +468,17 @@ local function can_push_train_vector(rule, task, learn_type, nspam, nham)
           local skip_rate = 1.0 - nspam / (nham + 1)
           if coin < skip_rate - train_opts.classes_bias then
             rspamd_logger.infox(task,
-                'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored',
-                learn_type,
-                skip_rate - train_opts.classes_bias,
-                nspam, nham)
+              'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored',
+              learn_type,
+              skip_rate - train_opts.classes_bias,
+              nspam, nham)
             return false
           end
         end
         return true
       else
         rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many ham samples: %s', learn_type,
-            nham)
+          nham)
       end
     end
   else
@@ -374,7 +489,7 @@ local function can_push_train_vector(rule, task, learn_type, nspam, nham)
         if train_opts.spam_skip_prob then
           if coin <= train_opts.spam_skip_prob then
             rspamd_logger.infox(task, 'skip %s sample probabilistically; probability %s (%s skip chance)', learn_type,
-                coin, train_opts.spam_skip_prob)
+              coin, train_opts.spam_skip_prob)
             return false
           end
 
@@ -382,14 +497,14 @@ local function can_push_train_vector(rule, task, learn_type, nspam, nham)
         end
       else
         rspamd_logger.infox(task, 'skip %s sample; too many spam samples: %s (%s limit)', learn_type,
-            nspam, train_opts.max_trains)
+          nspam, train_opts.max_trains)
       end
     else
       if nham <= train_opts.max_trains then
         if train_opts.ham_skip_prob then
           if coin <= train_opts.ham_skip_prob then
             rspamd_logger.infox(task, 'skip %s sample probabilistically; probability %s (%s skip chance)', learn_type,
-                coin, train_opts.ham_skip_prob)
+              coin, train_opts.ham_skip_prob)
             return false
           end
 
@@ -397,7 +512,7 @@ local function can_push_train_vector(rule, task, learn_type, nspam, nham)
         end
       else
         rspamd_logger.infox(task, 'skip %s sample; too many ham samples: %s (%s limit)', learn_type,
-            nham, train_opts.max_trains)
+          nham, train_opts.max_trains)
       end
     end
   end
@@ -410,10 +525,10 @@ local function gen_unlock_cb(rule, set, ann_key)
   return function(err)
     if err then
       rspamd_logger.errx(rspamd_config, 'cannot unlock ANN %s:%s at %s from redis: %s',
-          rule.prefix, set.name, ann_key, err)
+        rule.prefix, set.name, ann_key, err)
     else
       lua_util.debugm(N, rspamd_config, 'unlocked ANN %s:%s at %s',
-          rule.prefix, set.name, ann_key)
+        rule.prefix, set.name, ann_key)
     end
   end
 end
@@ -421,7 +536,7 @@ end
 -- Used to generate new ANN key for specific profile
 local function new_ann_key(rule, set, version)
   local ann_key = string.format('%s_%s_%s_%s_%s', settings.prefix,
-      rule.prefix, set.name, set.digest:sub(1, 8), tostring(version))
+    rule.prefix, set.name, set.digest:sub(1, 8), tostring(version))
 
   return ann_key
 end
@@ -430,7 +545,97 @@ local function redis_ann_prefix(rule, settings_name)
   -- We also need to count metatokens:
   local n = meta_functions.version
   return string.format('%s%d_%s_%d_%s',
-      settings.prefix, plugin_ver, rule.prefix, n, settings_name)
+    settings.prefix, plugin_ver, rule.prefix, n, settings_name)
+end
+
+-- Compute a stable digest for providers configuration
+local function providers_config_digest(providers_cfg)
+  if not providers_cfg then return nil end
+  -- Normalize minimal subset of fields to keep digest stable across equivalent configs
+  local norm = {}
+  for i, p in ipairs(providers_cfg) do
+    norm[i] = {
+      type = p.type,
+      name = p.name,
+      weight = p.weight or 1.0,
+      dim = p.dim,
+    }
+  end
+  return lua_util.table_digest(norm)
+end
+
+-- If no providers configured, fallback to symbols provider unless disabled
+-- phase: 'infer' | 'train'
+local function collect_features(task, rule, profile_or_set, phase)
+  local vectors = {}
+  local metas = {}
+
+  local providers_cfg = rule.providers
+  if not providers_cfg or #providers_cfg == 0 then
+    if not rule.disable_symbols_input then
+      local prov = get_provider('symbols')
+      if prov then
+        local vec, meta = prov.collect(task, { profile = profile_or_set, weight = 1.0 })
+        if vec then
+          vectors[#vectors + 1] = vec
+          metas[#metas + 1] = meta
+        end
+      end
+    end
+  else
+    for _, pcfg in ipairs(providers_cfg) do
+      local prov = get_provider(pcfg.type or pcfg.name)
+      if prov then
+        local ok, vec, meta = pcall(function()
+          return prov.collect(task, {
+            profile = profile_or_set,
+            rule = rule,
+            config = pcfg,
+            weight = pcfg.weight or 1.0,
+            phase = phase,
+          })
+        end)
+        if ok and vec then
+          if meta then
+            meta.weight = pcfg.weight or meta.weight or 1.0
+          end
+          vectors[#vectors + 1] = vec
+          metas[#metas + 1] = meta or
+              { name = pcfg.name or pcfg.type, type = pcfg.type, dim = #vec, weight = pcfg.weight or 1.0 }
+        else
+          rspamd_logger.debugm(N, rspamd_config, 'provider %s failed to collect features', pcfg.type or pcfg.name)
+        end
+      else
+        rspamd_logger.debugm(N, rspamd_config, 'provider %s is not registered', pcfg.type or pcfg.name)
+      end
+    end
+  end
+
+  -- Simple fusion by concatenation; optional per-provider weight scaling
+  local fused = {}
+  for i, v in ipairs(vectors) do
+    local w = (metas[i] and metas[i].weight) or 1.0
+    -- Apply normalization if requested
+    local norm_mode = (rule.fusion and rule.fusion.normalization) or 'none'
+    if norm_mode ~= 'none' then
+      v = apply_normalization(v, norm_mode)
+    end
+    for _, x in ipairs(v) do
+      fused[#fused + 1] = x * w
+    end
+  end
+
+  local meta = {
+    providers = build_providers_meta(metas) or metas,
+    total_dim = #fused,
+    digest = providers_config_digest(providers_cfg),
+  }
+
+  if #fused == 0 then
+    return nil, meta
+  end
+
+  return fused, meta
 end
 
 -- This function receives training vectors, checks them, spawn learning and saves ANN in Redis
@@ -488,71 +693,85 @@ local function spawn_train(params)
             -- We have nan :( try to log lot's of stuff to dig into a problem
             seen_nan = true
             rspamd_logger.errx(rspamd_config, 'ANN %s:%s: train error: observed nan in error cost!; value cost = %s',
-                params.rule.prefix, params.set.name,
-                value_cost)
+              params.rule.prefix, params.set.name,
+              value_cost)
             for i, e in ipairs(inputs) do
               lua_util.debugm(N, rspamd_config, 'train vector %s -> %s',
-                  debug_vec(e), outputs[i][1])
+                debug_vec(e), outputs[i][1])
             end
           end
 
           rspamd_logger.infox(rspamd_config,
-              "ANN %s:%s: learned from %s redis key in %s iterations, error: %s, value cost: %s",
-              params.rule.prefix, params.set.name,
-              params.ann_key,
-              iter,
-              train_cost,
-              value_cost)
+            "ANN %s:%s: learned from %s redis key in %s iterations, error: %s, value cost: %s",
+            params.rule.prefix, params.set.name,
+            params.ann_key,
+            iter,
+            train_cost,
+            value_cost)
         end
       end
 
       lua_util.debugm(N, rspamd_config, "subprocess to learn ANN %s:%s has been started",
-          params.rule.prefix, params.set.name)
+        params.rule.prefix, params.set.name)
 
       local pca
       if params.rule.max_inputs then
         -- Train PCA in the main process, presumably it is not that long
         lua_util.debugm(N, rspamd_config, "start PCA train for ANN %s:%s",
-            params.rule.prefix, params.set.name)
+          params.rule.prefix, params.set.name)
         pca = learn_pca(inputs, params.rule.max_inputs)
       end
 
+      -- Compute normalization stats if requested
+      local norm_stats
+      if params.rule.fusion and params.rule.fusion.normalization == 'zscore' then
+        norm_stats = compute_zscore_stats(inputs)
+      elseif params.rule.fusion and params.rule.fusion.normalization == 'unit' then
+        norm_stats = { mode = 'unit' }
+      end
+
+      if norm_stats then
+        for i = 1, #inputs do
+          inputs[i] = apply_normalization(inputs[i], norm_stats)
+        end
+      end
+
       lua_util.debugm(N, rspamd_config, "start neural train for ANN %s:%s",
-          params.rule.prefix, params.set.name)
+        params.rule.prefix, params.set.name)
       local ret, err = pcall(train_ann.train1, train_ann,
-          inputs, outputs, {
-            lr = params.rule.train.learning_rate,
-            max_epoch = params.rule.train.max_iterations,
-            cb = train_cb,
-            pca = pca
-          })
+        inputs, outputs, {
+          lr = params.rule.train.learning_rate,
+          max_epoch = params.rule.train.max_iterations,
+          cb = train_cb,
+          pca = pca
+        })
 
       if not ret then
         rspamd_logger.errx(rspamd_config, "cannot train ann %s:%s: %s",
-            params.rule.prefix, params.set.name, err)
+          params.rule.prefix, params.set.name, err)
 
         return nil
       else
         lua_util.debugm(N, rspamd_config, "finished neural train for ANN %s:%s",
-            params.rule.prefix, params.set.name)
+          params.rule.prefix, params.set.name)
       end
 
       local roc_thresholds = {}
       if params.rule.roc_enabled then
         local spam_threshold = get_roc_thresholds(train_ann,
-            inputs,
-            outputs,
-            1 - params.rule.roc_misclassification_cost,
-            params.rule.roc_misclassification_cost)
+          inputs,
+          outputs,
+          1 - params.rule.roc_misclassification_cost,
+          params.rule.roc_misclassification_cost)
         local ham_threshold = get_roc_thresholds(train_ann,
-            inputs,
-            outputs,
-            params.rule.roc_misclassification_cost,
-            1 - params.rule.roc_misclassification_cost)
+          inputs,
+          outputs,
+          params.rule.roc_misclassification_cost,
+          1 - params.rule.roc_misclassification_cost)
         roc_thresholds = { spam_threshold, ham_threshold }
 
         rspamd_logger.messagex("ROC thresholds: (spam_threshold: %s, ham_threshold: %s)",
-            roc_thresholds[1], roc_thresholds[2])
+          roc_thresholds[1], roc_thresholds[2])
       end
 
       if not seen_nan then
@@ -565,11 +784,12 @@ local function spawn_train(params)
           ann_data = tostring(train_ann:save()),
           pca_data = pca_data,
           roc_thresholds = roc_thresholds,
+          norm_stats = norm_stats,
         }
 
         local final_data = ucl.to_format(out, 'msgpack')
         lua_util.debugm(N, rspamd_config, "subprocess for ANN %s:%s returned %s bytes",
-            params.rule.prefix, params.set.name, #final_data)
+          params.rule.prefix, params.set.name, #final_data)
         return final_data
       else
         return nil
@@ -581,19 +801,19 @@ local function spawn_train(params)
     local function redis_save_cb(err)
       if err then
         rspamd_logger.errx(rspamd_config, 'cannot save ANN %s:%s to redis key %s: %s',
-            params.rule.prefix, params.set.name, params.ann_key, err)
+          params.rule.prefix, params.set.name, params.ann_key, err)
         lua_redis.redis_make_request_taskless(params.ev_base,
-            rspamd_config,
-            params.rule.redis,
-            nil,
-            false, -- is write
-            gen_unlock_cb(params.rule, params.set, params.ann_key), --callback
-            'HDEL', -- command
-            { params.ann_key, 'lock' }
+          rspamd_config,
+          params.rule.redis,
+          nil,
+          false,                                                  -- is write
+          gen_unlock_cb(params.rule, params.set, params.ann_key), --callback
+          'HDEL',                                                 -- command
+          { params.ann_key, 'lock' }
         )
       else
         rspamd_logger.infox(rspamd_config, 'saved ANN %s:%s to redis: %s',
-            params.rule.prefix, params.set.name, params.set.ann.redis_key)
+          params.rule.prefix, params.set.name, params.set.ann.redis_key)
       end
     end
 
@@ -601,15 +821,15 @@ local function spawn_train(params)
       params.set.learning_spawned = false
       if err then
         rspamd_logger.errx(rspamd_config, 'cannot train ANN %s:%s : %s',
-            params.rule.prefix, params.set.name, err)
+          params.rule.prefix, params.set.name, err)
         lua_redis.redis_make_request_taskless(params.ev_base,
-            rspamd_config,
-            params.rule.redis,
-            nil,
-            true, -- is write
-            gen_unlock_cb(params.rule, params.set, params.ann_key), --callback
-            'HDEL', -- command
-            { params.ann_key, 'lock' }
+          rspamd_config,
+          params.rule.redis,
+          nil,
+          true,                                                   -- is write
+          gen_unlock_cb(params.rule, params.set, params.ann_key), --callback
+          'HDEL',                                                 -- command
+          { params.ann_key, 'lock' }
         )
       else
         local parser = ucl.parser()
@@ -619,6 +839,7 @@ local function spawn_train(params)
         local ann_data = rspamd_util.zstd_compress(parsed.ann_data)
         local pca_data = parsed.pca_data
         local roc_thresholds = parsed.roc_thresholds
+        local norm_stats = parsed.norm_stats
 
         fill_set_ann(params.set, params.ann_key)
         if pca_data then
@@ -643,32 +864,40 @@ local function spawn_train(params)
           symbols = params.set.symbols,
           digest = params.set.digest,
           redis_key = params.set.ann.redis_key,
-          version = version
+          version = version,
+          providers_digest = providers_config_digest(params.rule.providers),
         }
 
         local profile_serialized = ucl.to_format(profile, 'json-compact', true)
         local roc_thresholds_serialized = ucl.to_format(roc_thresholds, 'json-compact', true)
+        local providers_meta_serialized
+        if params.rule.providers then
+          providers_meta_serialized = ucl.to_format(
+            build_providers_meta(params.set.ann.providers or params.rule.providers), 'json-compact', true)
+        end
 
         rspamd_logger.infox(rspamd_config,
-            'trained ANN %s:%s, %s bytes (%s compressed); %s rows in pca (%sb compressed); redis key: %s (old key %s)',
-            params.rule.prefix, params.set.name,
-            #data, #ann_data,
-            #(params.set.ann.pca or {}), #(pca_data or {}),
-            params.set.ann.redis_key, params.ann_key)
+          'trained ANN %s:%s, %s bytes (%s compressed); %s rows in pca (%sb compressed); redis key: %s (old key %s)',
+          params.rule.prefix, params.set.name,
+          #data, #ann_data,
+          #(params.set.ann.pca or {}), #(pca_data or {}),
+          params.set.ann.redis_key, params.ann_key)
 
         lua_redis.exec_redis_script(redis_script_id.save_unlock,
-            { ev_base = params.ev_base, is_write = true },
-            redis_save_cb,
-            { profile.redis_key,
-              redis_ann_prefix(params.rule, params.set.name),
-              ann_data,
-              profile_serialized,
-              tostring(params.rule.ann_expire),
-              tostring(os.time()),
-              params.ann_key, -- old key to unlock...
-              roc_thresholds_serialized,
-              pca_data,
-            })
+          { ev_base = params.ev_base, is_write = true },
+          redis_save_cb,
+          { profile.redis_key,
+            redis_ann_prefix(params.rule, params.set.name),
+            ann_data,
+            profile_serialized,
+            tostring(params.rule.ann_expire),
+            tostring(os.time()),
+            params.ann_key, -- old key to unlock...
+            roc_thresholds_serialized,
+            pca_data,
+            providers_meta_serialized,
+            ucl.to_format(norm_stats, 'json-compact', true),
+          })
       end
     end
 
@@ -685,7 +914,6 @@ local function spawn_train(params)
     params.set.learning_spawned = true
     register_lock_extender(params.rule, params.set, params.ev_base, params.ann_key)
     return
-
   end
 end
 
@@ -698,14 +926,14 @@ local function process_rules_settings()
       -- Use static user defined profile
       -- Ensure that we have an array...
       lua_util.debugm(N, rspamd_config, "use static profile for %s (%s): %s",
-          rule.prefix, selt.name, profile)
+        rule.prefix, selt.name, profile)
       if not profile[1] then
         profile = lua_util.keys(profile)
       end
       selt.symbols = profile
     else
       lua_util.debugm(N, rspamd_config, "use dynamic cfg based profile for %s (%s)",
-          rule.prefix, selt.name)
+        rule.prefix, selt.name)
     end
 
     local function filter_symbols_predicate(sname)
@@ -734,34 +962,34 @@ local function process_rules_settings()
     selt.prefix = redis_ann_prefix(rule, selt.name)
 
     rspamd_logger.messagex(rspamd_config,
-        'use NN prefix for rule %s; settings id "%s"; symbols digest: "%s"',
-        selt.prefix, selt.name, selt.digest)
+      'use NN prefix for rule %s; settings id "%s"; symbols digest: "%s"',
+      selt.prefix, selt.name, selt.digest)
 
     lua_redis.register_prefix(selt.prefix, N,
-        string.format('NN prefix for rule "%s"; settings id "%s"',
-            selt.prefix, selt.name), {
-          persistent = true,
-          type = 'zlist',
-        })
+      string.format('NN prefix for rule "%s"; settings id "%s"',
+        selt.prefix, selt.name), {
+        persistent = true,
+        type = 'zlist',
+      })
     -- Versions
     lua_redis.register_prefix(selt.prefix .. '_\\d+', N,
-        string.format('NN storage for rule "%s"; settings id "%s"',
-            selt.prefix, selt.name), {
-          persistent = true,
-          type = 'hash',
-        })
+      string.format('NN storage for rule "%s"; settings id "%s"',
+        selt.prefix, selt.name), {
+        persistent = true,
+        type = 'hash',
+      })
     lua_redis.register_prefix(selt.prefix .. '_\\d+_spam_set', N,
-        string.format('NN learning set (spam) for rule "%s"; settings id "%s"',
-            selt.prefix, selt.name), {
-          persistent = true,
-          type = 'set',
-        })
+      string.format('NN learning set (spam) for rule "%s"; settings id "%s"',
+        selt.prefix, selt.name), {
+        persistent = true,
+        type = 'set',
+      })
     lua_redis.register_prefix(selt.prefix .. '_\\d+_ham_set', N,
-        string.format('NN learning set (ham) for rule "%s"; settings id "%s"',
-            rule.prefix, selt.name), {
-          persistent = true,
-          type = 'set',
-        })
+      string.format('NN learning set (ham) for rule "%s"; settings id "%s"',
+        rule.prefix, selt.name), {
+        persistent = true,
+        type = 'set',
+      })
   end
 
   for k, rule in pairs(settings.rules) do
@@ -813,8 +1041,8 @@ local function process_rules_settings()
           if nelt and lua_util.distance_sorted(ex.symbols, nelt.symbols) == 0 then
             -- Equal symbols, add reference
             lua_util.debugm(N, rspamd_config,
-                'added reference from settings id %s to %s; same symbols',
-                nelt.name, ex.name)
+              'added reference from settings id %s to %s; same symbols',
+              nelt.name, ex.name)
             rule.settings[settings_id] = id
             nelt = nil
           end
@@ -824,7 +1052,7 @@ local function process_rules_settings()
       if nelt then
         rule.settings[settings_id] = nelt
         lua_util.debugm(N, rspamd_config, 'added new settings id %s(%s) to %s',
-            nelt.name, settings_id, rule.prefix)
+          nelt.name, settings_id, rule.prefix)
       end
     end
   end
@@ -847,7 +1075,7 @@ local function get_rule_settings(task, rule)
   return set
 end
 
-local function result_to_vector(task, profile)
+result_to_vector = function(task, profile)
   if not profile.zeros then
     -- Fill zeros vector
     local zeros = {}
@@ -874,13 +1102,18 @@ end
 
 return {
   can_push_train_vector = can_push_train_vector,
+  collect_features = collect_features,
   create_ann = create_ann,
   default_options = default_options,
+  build_providers_meta = build_providers_meta,
+  apply_normalization = apply_normalization,
   gen_unlock_cb = gen_unlock_cb,
   get_rule_settings = get_rule_settings,
   load_scripts = load_scripts,
   module_config = module_config,
   new_ann_key = new_ann_key,
+  providers_config_digest = providers_config_digest,
+  register_provider = register_provider,
   plugin_ver = plugin_ver,
   process_rules_settings = process_rules_settings,
   redis_ann_prefix = redis_ann_prefix,
diff --git a/lualib/plugins/neural/providers/llm.lua b/lualib/plugins/neural/providers/llm.lua
new file mode 100644 (file)
index 0000000..fda0141
--- /dev/null
@@ -0,0 +1,206 @@
+--[[
+LLM provider for neural feature fusion
+Collects text from the most relevant part and requests embeddings from an LLM API.
+Supports minimal OpenAI- and Ollama-compatible embedding endpoints.
+]] --
+
+local rspamd_http = require "rspamd_http"
+local rspamd_logger = require "rspamd_logger"
+local ucl = require "ucl"
+local lua_mime = require "lua_mime"
+local neural_common = require "plugins/neural"
+local lua_cache = require "lua_cache"
+
+local N = "neural.llm"
+
+local function select_text(task, cfg)
+  local part = lua_mime.get_displayed_text_part(task)
+  if part then
+    local tp = part:get_text()
+    if tp then
+      -- Prefer UTF text content
+      local content = tp:get_content('raw_utf') or tp:get_content('raw')
+      if content and #content > 0 then
+        return content
+      end
+    end
+    -- Fallback to raw content
+    local rc = part:get_raw_content()
+    if type(rc) == 'userdata' then
+      rc = tostring(rc)
+    end
+    return rc
+  end
+
+  -- Fallback to subject if no text part
+  return task:get_subject() or ''
+end
+
+local function compose_llm_settings(pcfg)
+  local gpt_settings = rspamd_config:get_all_opt('gpt') or {}
+  local llm_type = pcfg.type or gpt_settings.type or 'openai'
+  local model = pcfg.model or gpt_settings.model
+  local timeout = pcfg.timeout or gpt_settings.timeout or 2.0
+  local url = pcfg.url
+  local api_key = pcfg.api_key or gpt_settings.api_key
+
+  if not url then
+    if llm_type == 'openai' then
+      url = 'https://api.openai.com/v1/embeddings'
+    elseif llm_type == 'ollama' then
+      url = 'http://127.0.0.1:11434/api/embeddings'
+    end
+  end
+
+  return {
+    type = llm_type,
+    model = model,
+    timeout = timeout,
+    url = url,
+    api_key = api_key,
+    cache_ttl = pcfg.cache_ttl or 86400,
+    cache_prefix = pcfg.cache_prefix or 'neural_llm',
+    cache_hash_len = pcfg.cache_hash_len or 16,
+    cache_use_hashing = pcfg.cache_use_hashing ~= false,
+  }
+end
+
+local function extract_embedding(llm_type, parsed)
+  if llm_type == 'openai' then
+    -- { data = [ { embedding = [...] } ] }
+    if parsed and parsed.data and parsed.data[1] and parsed.data[1].embedding then
+      return parsed.data[1].embedding
+    end
+  elseif llm_type == 'ollama' then
+    -- { embedding = [...] }
+    if parsed and parsed.embedding then
+      return parsed.embedding
+    end
+  end
+  return nil
+end
+
+neural_common.register_provider('llm', {
+  collect = function(task, ctx)
+    local pcfg = ctx.config or {}
+    local llm = compose_llm_settings(pcfg)
+
+    if not llm.model then
+      rspamd_logger.debugm(N, task, 'llm provider missing model; skip')
+      return nil
+    end
+
+    local content = select_text(task, pcfg)
+    if not content or #content == 0 then
+      rspamd_logger.debugm(N, task, 'llm provider has no content to embed; skip')
+      return nil
+    end
+
+    local body
+    if llm.type == 'openai' then
+      body = { model = llm.model, input = content }
+    elseif llm.type == 'ollama' then
+      body = { model = llm.model, prompt = content }
+    else
+      rspamd_logger.debugm(N, task, 'unsupported llm type: %s', llm.type)
+      return nil
+    end
+
+    -- Redis cache: use content hash + model + provider as key
+    local cache_ctx = lua_cache.create_cache_context(neural_common.redis_params, {
+      cache_prefix = llm.cache_prefix,
+      cache_ttl = llm.cache_ttl,
+      cache_format = 'messagepack',
+      cache_hash_len = llm.cache_hash_len,
+      cache_use_hashing = llm.cache_use_hashing,
+    }, N)
+
+    -- Use a stable key based on content digest
+    local hasher = require 'rspamd_cryptobox_hash'
+    local key = string.format('%s:%s:%s', llm.type, llm.model or 'model', hasher.create(content):hex())
+
+    local function do_request_and_cache()
+      local headers = { ['Content-Type'] = 'application/json' }
+      if llm.type == 'openai' and llm.api_key then
+        headers['Authorization'] = 'Bearer ' .. llm.api_key
+      end
+
+      local http_params = {
+        url = llm.url,
+        mime_type = 'application/json',
+        timeout = llm.timeout,
+        log_obj = task,
+        headers = headers,
+        body = ucl.to_format(body, 'json-compact', true),
+        task = task,
+        method = 'POST',
+        use_gzip = true,
+      }
+
+      local err, data = rspamd_http.request(http_params)
+      if err then
+        rspamd_logger.debugm(N, task, 'llm request failed: %s', err)
+        return nil
+      end
+
+      local parser = ucl.parser()
+      local ok, perr = parser:parse_string(data.content)
+      if not ok then
+        rspamd_logger.debugm(N, task, 'cannot parse llm response: %s', perr)
+        return nil
+      end
+
+      local parsed = parser:get_object()
+      local embedding = extract_embedding(llm.type, parsed)
+      if not embedding or #embedding == 0 then
+        rspamd_logger.debugm(N, task, 'no embedding in llm response')
+        return nil
+      end
+
+      for i = 1, #embedding do
+        embedding[i] = tonumber(embedding[i]) or 0.0
+      end
+
+      lua_cache.cache_set(task, key, { e = embedding }, cache_ctx)
+      return embedding
+    end
+
+    -- Try cache first
+    local cached_result
+    local done = false
+    lua_cache.cache_get(task, key, cache_ctx, llm.timeout or 2.0,
+      function(_)
+        -- Uncached: perform request synchronously and store
+        cached_result = do_request_and_cache()
+        done = true
+      end,
+      function(_, err, data)
+        if data and data.e then
+          cached_result = data.e
+        end
+        done = true
+      end
+    )
+
+    if not done then
+      -- Fallback: ensure we still do the request now (cache API is async-ready, but we need sync path)
+      cached_result = do_request_and_cache()
+    end
+
+    local embedding = cached_result
+    if not embedding then
+      return nil
+    end
+
+    local meta = {
+      name = pcfg.name or 'llm',
+      type = 'llm',
+      dim = #embedding,
+      weight = pcfg.weight or 1.0,
+      model = llm.model,
+      provider = llm.type,
+    }
+
+    return embedding, meta
+  end
+})
diff --git a/lualib/plugins/neural/providers/symbols.lua b/lualib/plugins/neural/providers/symbols.lua
new file mode 100644 (file)
index 0000000..6a3b750
--- /dev/null
@@ -0,0 +1,10 @@
+-- Symbols provider: wraps legacy symbols+metatokens vectorization
+
+local neural_common = require "plugins/neural"
+
+neural_common.register_provider('symbols', {
+  collect = function(task, ctx)
+    local vec = neural_common.result_to_vector(task, ctx.profile)
+    return vec, { name = 'symbols', type = 'symbols', dim = #vec, weight = ctx.weight or 1.0 }
+  end
+})
index 7ea7dc2e58a61eb4293312b0514d51acebfab3d9..dfed2e358f90933d46c85c06d429a8399dd7c969 100644 (file)
@@ -9,6 +9,8 @@
 -- key7 - old key
 -- key8 - ROC Thresholds
 -- key9 - optional PCA
+-- key10 - optional providers_meta (JSON)
+-- key11 - optional norm_stats (JSON)
 local now = tonumber(KEYS[6])
 redis.call('ZADD', KEYS[2], now, KEYS[4])
 redis.call('HSET', KEYS[1], 'ann', KEYS[3])
@@ -16,10 +18,16 @@ redis.call('HSET', KEYS[1], 'roc_thresholds', KEYS[8])
 if KEYS[9] then
   redis.call('HSET', KEYS[1], 'pca', KEYS[9])
 end
+if KEYS[10] then
+  redis.call('HSET', KEYS[1], 'providers_meta', KEYS[10])
+end
+if KEYS[11] then
+  redis.call('HSET', KEYS[1], 'norm_stats', KEYS[11])
+end
 redis.call('HDEL', KEYS[1], 'lock')
 redis.call('HDEL', KEYS[7], 'lock')
 redis.call('EXPIRE', KEYS[1], tonumber(KEYS[5]))
- -- expire in 10m, to not face race condition with other rspamd replicas refill deleted keys
+-- expire in 10m, to not face race condition with other rspamd replicas refill deleted keys
 redis.call('EXPIRE', KEYS[7] .. '_spam_set', 600)
 redis.call('EXPIRE', KEYS[7] .. '_ham_set', 600)
 return 1
index ea40fc4f7417ac248c41464ddfb40ad592f69bef..0a8ebcd6926d4d219dcdb5cd061055a333c7eb68 100644 (file)
@@ -12,7 +12,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
-]]--
+]] --
 
 
 if confighelp then
@@ -30,6 +30,9 @@ local rspamd_tensor = require "rspamd_tensor"
 local rspamd_text = require "rspamd_text"
 local rspamd_util = require "rspamd_util"
 local ts = require("tableshape").types
+-- Load providers
+pcall(require, "plugins/neural/providers/llm")
+pcall(require, "plugins/neural/providers/symbols")
 
 local N = "neural"
 
@@ -41,6 +44,7 @@ local redis_profile_schema = ts.shape {
   version = ts.number,
   redis_key = ts.string,
   distance = ts.number:is_optional(),
+  providers_digest = ts.string:is_optional(),
 }
 
 local has_blas = rspamd_tensor.has_blas()
@@ -55,7 +59,8 @@ local function new_ann_profile(task, rule, set, version)
     redis_key = ann_key,
     version = version,
     digest = set.digest,
-    distance = 0 -- Since we are using our own profile
+    distance = 0, -- Since we are using our own profile
+    providers_digest = neural_common.providers_config_digest(rule.providers),
   }
 
   local ucl = require "ucl"
@@ -64,20 +69,20 @@ local function new_ann_profile(task, rule, set, version)
   local function add_cb(err, _)
     if err then
       rspamd_logger.errx(task, 'cannot store ANN profile for %s:%s at %s : %s',
-          rule.prefix, set.name, profile.redis_key, err)
+        rule.prefix, set.name, profile.redis_key, err)
     else
       rspamd_logger.infox(task, 'created new ANN profile for %s:%s, data stored at prefix %s',
-          rule.prefix, set.name, profile.redis_key)
+        rule.prefix, set.name, profile.redis_key)
     end
   end
 
   lua_redis.redis_make_request(task,
-      rule.redis,
-      nil,
-      true, -- is write
-      add_cb, --callback
-      'ZADD', -- command
-      { set.prefix, tostring(rspamd_util.get_time()), profile_serialized }
+    rule.redis,
+    nil,
+    true,   -- is write
+    add_cb, --callback
+    'ZADD', -- command
+    { set.prefix, tostring(rspamd_util.get_time()), profile_serialized }
   )
 
   return profile
@@ -86,7 +91,6 @@ end
 
 -- ANN filter function, used to insert scores based on the existing symbols
 local function ann_scores_filter(task)
-
   for _, rule in pairs(settings.rules) do
     local sid = task:get_settings_id() or -1
     local ann
@@ -99,24 +103,41 @@ local function ann_scores_filter(task)
         profile = set.ann
       else
         lua_util.debugm(N, task, 'no ann loaded for %s:%s',
-            rule.prefix, set.name)
+          rule.prefix, set.name)
       end
     else
       lua_util.debugm(N, task, 'no ann defined in %s for settings id %s',
-          rule.prefix, sid)
+        rule.prefix, sid)
     end
 
     if ann then
-      local vec = neural_common.result_to_vector(task, profile)
+      local vec
+      if rule.providers and #rule.providers > 0 then
+        local fused, meta = neural_common.collect_features(task, rule, profile)
+        vec = fused
+        if profile.providers_digest and meta.digest and profile.providers_digest ~= meta.digest then
+          lua_util.debugm(N, task, 'providers digest mismatch for %s:%s, skip ANN apply',
+            rule.prefix, set.name)
+          vec = nil
+        end
+      else
+        vec = neural_common.result_to_vector(task, profile)
+      end
 
       local score
+      if not vec then
+        goto continue_rule
+      end
+      if set.ann.norm_stats then
+        vec = neural_common.apply_normalization(vec, set.ann.norm_stats)
+      end
       local out = ann:apply1(vec, set.ann.pca)
       score = out[1]
 
       local symscore = string.format('%.3f', score)
       task:cache_set(rule.prefix .. '_neural_score', score)
       lua_util.debugm(N, task, '%s:%s:%s ann score: %s',
-          rule.prefix, set.name, set.ann.version, symscore)
+        rule.prefix, set.name, set.ann.version, symscore)
 
       if score > 0 then
         local result = score
@@ -137,8 +158,8 @@ local function ann_scores_filter(task)
           end
         else
           lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (spam threshold)',
-              rule.prefix, set.name, set.ann.version, symscore,
-              spam_threshold)
+            rule.prefix, set.name, set.ann.version, symscore,
+            spam_threshold)
         end
       else
         local result = -(score)
@@ -159,11 +180,12 @@ local function ann_scores_filter(task)
           end
         else
           lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (ham threshold)',
-              rule.prefix, set.name, set.ann.version, result,
-              ham_threshold)
+            rule.prefix, set.name, set.ann.version, result,
+            ham_threshold)
         end
       end
     end
+    ::continue_rule::
   end
 end
 
@@ -178,14 +200,14 @@ local function ann_push_task_result(rule, task, verdict, score, set)
 
       if not learn_spam then
         skip_reason = string.format('score < spam_score: %f < %f',
-            score, train_opts.spam_score)
+          score, train_opts.spam_score)
       end
     else
       learn_spam = verdict == 'spam' or verdict == 'junk'
 
       if not learn_spam then
         skip_reason = string.format('verdict: %s',
-            verdict)
+          verdict)
       end
     end
 
@@ -193,14 +215,14 @@ local function ann_push_task_result(rule, task, verdict, score, set)
       learn_ham = score <= train_opts.ham_score
       if not learn_ham then
         skip_reason = string.format('score > ham_score: %f > %f',
-            score, train_opts.ham_score)
+          score, train_opts.ham_score)
       end
     else
       learn_ham = verdict == 'ham'
 
       if not learn_ham then
         skip_reason = string.format('verdict: %s',
-            verdict)
+          verdict)
       end
     end
   else
@@ -221,7 +243,16 @@ local function ann_push_task_result(rule, task, verdict, score, set)
       learn_spam = false
 
       -- Explicitly store tokens in cache
-      local vec = neural_common.result_to_vector(task, set)
+      local vec
+      if rule.providers and #rule.providers > 0 then
+        local fused = neural_common.collect_features(task, rule, set, 'train')
+        if type(fused) == 'table' then
+          vec = fused
+        end
+      end
+      if not vec then
+        vec = neural_common.result_to_vector(task, set)
+      end
       task:cache_set(rule.prefix .. '_neural_vec_mpack', ucl.to_format(vec, 'msgpack'))
       task:cache_set(rule.prefix .. '_neural_profile_digest', set.digest)
       skip_reason = 'store_pool_only has been set'
@@ -241,7 +272,16 @@ local function ann_push_task_result(rule, task, verdict, score, set)
         local nspam, nham = data[1], data[2]
 
         if neural_common.can_push_train_vector(rule, task, learn_type, nspam, nham) then
-          local vec = neural_common.result_to_vector(task, set)
+          local vec
+          if rule.providers and #rule.providers > 0 then
+            local fused = neural_common.collect_features(task, rule, set)
+            if type(fused) == 'table' then
+              vec = fused
+            end
+          end
+          if not vec then
+            vec = neural_common.result_to_vector(task, set)
+          end
 
           local str = rspamd_util.zstd_compress(table.concat(vec, ';'))
           local target_key = set.ann.redis_key .. '_' .. learn_type .. '_set'
@@ -249,41 +289,41 @@ local function ann_push_task_result(rule, task, verdict, score, set)
           local function learn_vec_cb(redis_err)
             if redis_err then
               rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s',
-                  rule.prefix, set.name, redis_err)
+                rule.prefix, set.name, redis_err)
             else
               lua_util.debugm(N, task,
-                  "add train data for ANN rule " ..
-                      "%s:%s, save %s vector of %s elts in %s key; %s bytes compressed",
-                  rule.prefix, set.name, learn_type, #vec, target_key, #str)
+                "add train data for ANN rule " ..
+                "%s:%s, save %s vector of %s elts in %s key; %s bytes compressed",
+                rule.prefix, set.name, learn_type, #vec, target_key, #str)
             end
           end
 
           lua_redis.redis_make_request(task,
-              rule.redis,
-              nil,
-              true, -- is write
-              learn_vec_cb, --callback
-              'SADD', -- command
-              { target_key, str } -- arguments
+            rule.redis,
+            nil,
+            true,               -- is write
+            learn_vec_cb,       --callback
+            'SADD',             -- command
+            { target_key, str } -- arguments
           )
         else
           lua_util.debugm(N, task,
-              "do not add %s train data for ANN rule " ..
-                  "%s:%s",
-              learn_type, rule.prefix, set.name)
+            "do not add %s train data for ANN rule " ..
+            "%s:%s",
+            learn_type, rule.prefix, set.name)
         end
       else
         if err then
           rspamd_logger.errx(task, 'cannot check if we can train %s:%s : %s',
-              rule.prefix, set.name, err)
+            rule.prefix, set.name, err)
         elseif type(data) == 'string' then
           -- nil return value
           rspamd_logger.infox(task, "cannot learn %s ANN %s:%s; redis_key: %s: locked for learning: %s",
-              learn_type, rule.prefix, set.name, set.ann.redis_key, data)
+            learn_type, rule.prefix, set.name, set.ann.redis_key, data)
         else
           rspamd_logger.errx(task, 'cannot check if we can train %s:%s : type of Redis key %s is %s, expected table' ..
-              'please remove this key from Redis manually if you perform upgrade from the previous version',
-              rule.prefix, set.name, set.ann.redis_key, type(data))
+            'please remove this key from Redis manually if you perform upgrade from the previous version',
+            rule.prefix, set.name, set.ann.redis_key, type(data))
         end
       end
     end
@@ -294,25 +334,25 @@ local function ann_push_task_result(rule, task, verdict, score, set)
         -- Need to create or load a profile corresponding to the current configuration
         set.ann = new_ann_profile(task, rule, set, 0)
         lua_util.debugm(N, task,
-            'requested new profile for %s, set.ann is missing',
-            set.name)
+          'requested new profile for %s, set.ann is missing',
+          set.name)
       end
 
       lua_redis.exec_redis_script(neural_common.redis_script_id.vectors_len,
-          { task = task, is_write = false },
-          vectors_len_cb,
-          {
-            set.ann.redis_key,
-          })
+        { task = task, is_write = false },
+        vectors_len_cb,
+        {
+          set.ann.redis_key,
+        })
     else
       lua_util.debugm(N, task,
-          'do not push data: train condition not satisfied; reason: not checked existing ANNs')
+        'do not push data: train condition not satisfied; reason: not checked existing ANNs')
     end
   else
     lua_util.debugm(N, task,
-        'do not push data to key %s: train condition not satisfied; reason: %s',
-        (set.ann or {}).redis_key,
-        skip_reason)
+      'do not push data to key %s: train condition not satisfied; reason: %s',
+      (set.ann or {}).redis_key,
+      skip_reason)
   end
 end
 
@@ -337,23 +377,29 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key)
   local function redis_ham_cb(err, data)
     if err or type(data) ~= 'table' then
       rspamd_logger.errx(rspamd_config, 'cannot get ham tokens for ANN %s from redis: %s',
-          ann_key, err)
+        ann_key, err)
       -- Unlock on error
       lua_redis.redis_make_request_taskless(ev_base,
-          rspamd_config,
-          rule.redis,
-          nil,
-          true, -- is write
-          neural_common.gen_unlock_cb(rule, set, ann_key), --callback
-          'HDEL', -- command
-          { ann_key, 'lock' }
+        rspamd_config,
+        rule.redis,
+        nil,
+        true,                                            -- is write
+        neural_common.gen_unlock_cb(rule, set, ann_key), --callback
+        'HDEL',                                          -- command
+        { ann_key, 'lock' }
       )
     else
       -- Decompress and convert to numbers each training vector
       ham_elts = process_training_vectors(data)
-      neural_common.spawn_train({ worker = worker, ev_base = ev_base,
-                                  rule = rule, set = set, ann_key = ann_key, ham_vec = ham_elts,
-                                  spam_vec = spam_elts })
+      neural_common.spawn_train({
+        worker = worker,
+        ev_base = ev_base,
+        rule = rule,
+        set = set,
+        ann_key = ann_key,
+        ham_vec = ham_elts,
+        spam_vec = spam_elts
+      })
     end
   end
 
@@ -361,29 +407,29 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key)
   local function redis_spam_cb(err, data)
     if err or type(data) ~= 'table' then
       rspamd_logger.errx(rspamd_config, 'cannot get spam tokens for ANN %s from redis: %s',
-          ann_key, err)
+        ann_key, err)
       -- Unlock ANN on error
       lua_redis.redis_make_request_taskless(ev_base,
-          rspamd_config,
-          rule.redis,
-          nil,
-          true, -- is write
-          neural_common.gen_unlock_cb(rule, set, ann_key), --callback
-          'HDEL', -- command
-          { ann_key, 'lock' }
+        rspamd_config,
+        rule.redis,
+        nil,
+        true,                                            -- is write
+        neural_common.gen_unlock_cb(rule, set, ann_key), --callback
+        'HDEL',                                          -- command
+        { ann_key, 'lock' }
       )
     else
       -- Decompress and convert to numbers each training vector
       spam_elts = process_training_vectors(data)
       -- Now get ham vectors...
       lua_redis.redis_make_request_taskless(ev_base,
-          rspamd_config,
-          rule.redis,
-          nil,
-          false, -- is write
-          redis_ham_cb, --callback
-          'SMEMBERS', -- command
-          { ann_key .. '_ham_set' }
+        rspamd_config,
+        rule.redis,
+        nil,
+        false,        -- is write
+        redis_ham_cb, --callback
+        'SMEMBERS',   -- command
+        { ann_key .. '_ham_set' }
       )
     end
   end
@@ -391,33 +437,33 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key)
   local function redis_lock_cb(err, data)
     if err then
       rspamd_logger.errx(rspamd_config, 'cannot call lock script for ANN %s from redis: %s',
-          ann_key, err)
+        ann_key, err)
     elseif type(data) == 'number' and data == 1 then
       -- ANN is locked, so we can extract SPAM and HAM vectors and spawn learning
       lua_redis.redis_make_request_taskless(ev_base,
-          rspamd_config,
-          rule.redis,
-          nil,
-          false, -- is write
-          redis_spam_cb, --callback
-          'SMEMBERS', -- command
-          { ann_key .. '_spam_set' }
+        rspamd_config,
+        rule.redis,
+        nil,
+        false,         -- is write
+        redis_spam_cb, --callback
+        'SMEMBERS',    -- command
+        { ann_key .. '_spam_set' }
       )
 
       rspamd_logger.infox(rspamd_config, 'lock ANN %s:%s (key name %s) for learning',
-          rule.prefix, set.name, ann_key)
+        rule.prefix, set.name, ann_key)
     else
       local lock_tm = tonumber(data[1])
       rspamd_logger.infox(rspamd_config, 'do not learn ANN %s:%s (key name %s), ' ..
-          'locked by another host %s at %s', rule.prefix, set.name, ann_key,
-          data[2], os.date('%c', lock_tm))
+        'locked by another host %s at %s', rule.prefix, set.name, ann_key,
+        data[2], os.date('%c', lock_tm))
     end
   end
 
   -- Check if we are already learning this network
   if set.learning_spawned then
     rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, already learning another ANN',
-        ann_key)
+      ann_key)
     return
   end
 
@@ -425,14 +471,14 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key)
   -- This script returns either a boolean or a pair {'lock_time', 'hostname'} when
   -- ANN is locked by another host (or a process, meh)
   lua_redis.exec_redis_script(neural_common.redis_script_id.maybe_lock,
-      { ev_base = ev_base, is_write = true },
-      redis_lock_cb,
-      {
-        ann_key,
-        tostring(os.time()),
-        tostring(math.max(10.0, rule.watch_interval * 2)),
-        rspamd_util.get_hostname()
-      })
+    { ev_base = ev_base, is_write = true },
+    redis_lock_cb,
+    {
+      ann_key,
+      tostring(os.time()),
+      tostring(math.max(10.0, rule.watch_interval * 2)),
+      rspamd_util.get_hostname()
+    })
 end
 
 -- This function loads new ann from Redis
@@ -447,7 +493,7 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
   local function data_cb(err, data)
     if err then
       rspamd_logger.errx(rspamd_config, 'cannot get ANN data from key: %s; %s',
-          ann_key, err)
+        ann_key, err)
     else
       if type(data) == 'table' then
         if type(data[1]) == 'userdata' and data[1].cookie == text_cookie then
@@ -456,7 +502,7 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
 
           if _err or not ann_data then
             rspamd_logger.errx(rspamd_config, 'cannot decompress ANN for %s from Redis key %s: %s',
-                rule.prefix .. ':' .. set.name, ann_key, _err)
+              rule.prefix .. ':' .. set.name, ann_key, _err)
             return
           else
             ann = rspamd_kann.load(ann_data)
@@ -467,7 +513,8 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
                 version = profile.version,
                 symbols = profile.symbols,
                 distance = min_diff,
-                redis_key = profile.redis_key
+                redis_key = profile.redis_key,
+                providers_digest = profile.providers_digest,
               }
 
               local ucl = require "ucl"
@@ -479,26 +526,26 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
               end
               -- Also update rank for the loaded ANN to avoid removal
               lua_redis.redis_make_request_taskless(ev_base,
-                  rspamd_config,
-                  rule.redis,
-                  nil,
-                  true, -- is write
-                  rank_cb, --callback
-                  'ZADD', -- command
-                  { set.prefix, tostring(rspamd_util.get_time()), profile_serialized }
+                rspamd_config,
+                rule.redis,
+                nil,
+                true,    -- is write
+                rank_cb, --callback
+                'ZADD',  -- command
+                { set.prefix, tostring(rspamd_util.get_time()), profile_serialized }
               )
               rspamd_logger.infox(rspamd_config,
-                  'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s',
-                  rule.prefix, set.name, ann_key, #data[1], profile.version)
+                'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s',
+                rule.prefix, set.name, ann_key, #data[1], profile.version)
             else
               rspamd_logger.errx(rspamd_config,
-                  'cannot unpack/deserialise ANN for %s:%s from Redis key %s',
-                  rule.prefix, set.name, ann_key)
+                'cannot unpack/deserialise ANN for %s:%s from Redis key %s',
+                rule.prefix, set.name, ann_key)
             end
           end
         else
           lua_util.debugm(N, rspamd_config, 'missing ANN for %s:%s in Redis key %s',
-              rule.prefix, set.name, ann_key)
+            rule.prefix, set.name, ann_key)
         end
 
         if set.ann and set.ann.ann and type(data[2]) == 'userdata' and data[2].cookie == text_cookie then
@@ -510,8 +557,8 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
             local roc_thresholds = parser:get_object()
             set.ann.roc_thresholds = roc_thresholds
             rspamd_logger.infox(rspamd_config,
-                'loaded ROC thresholds for %s:%s; version=%s',
-                rule.prefix, set.name, profile.version)
+              'loaded ROC thresholds for %s:%s; version=%s',
+              rule.prefix, set.name, profile.version)
             rspamd_logger.debugx("ROC thresholds: %s", roc_thresholds)
           end
         end
@@ -524,19 +571,19 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
               -- We can use PCA
               set.ann.pca = rspamd_tensor.load(pca_data)
               rspamd_logger.infox(rspamd_config,
-                  'loaded PCA for ANN for %s:%s from %s; %s bytes compressed; version=%s',
-                  rule.prefix, set.name, ann_key, #data[3], profile.version)
+                'loaded PCA for ANN for %s:%s from %s; %s bytes compressed; version=%s',
+                rule.prefix, set.name, ann_key, #data[3], profile.version)
             else
               -- no need in pca, why is it there?
               rspamd_logger.warnx(rspamd_config,
-                  'extra PCA for ANN for %s:%s from Redis key %s: no max inputs defined',
-                  rule.prefix, set.name, ann_key)
+                'extra PCA for ANN for %s:%s from Redis key %s: no max inputs defined',
+                rule.prefix, set.name, ann_key)
             end
           else
             -- pca can be missing merely if we have no max_inputs
             if rule.max_inputs then
               rspamd_logger.errx(rspamd_config, 'cannot unpack/deserialise ANN for %s:%s from Redis key %s: no PCA: %s',
-                  rule.prefix, set.name, ann_key, _err)
+                rule.prefix, set.name, ann_key, _err)
               set.ann.ann = nil
             else
               -- It is okay
@@ -545,21 +592,39 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
           end
         end
 
+        -- Providers meta (optional)
+        if set.ann and set.ann.ann and type(data[4]) == 'userdata' and data[4].cookie == text_cookie then
+          local ucl = require "ucl"
+          local parser = ucl.parser()
+          local ok = parser:parse_text(data[4])
+          if ok then
+            set.ann.providers_meta = parser:get_object()
+          end
+        end
+        -- Normalization stats (optional)
+        if set.ann and set.ann.ann and type(data[5]) == 'userdata' and data[5].cookie == text_cookie then
+          local ucl = require "ucl"
+          local parser = ucl.parser()
+          local ok = parser:parse_text(data[5])
+          if ok then
+            set.ann.norm_stats = parser:get_object()
+          end
+        end
       else
         lua_util.debugm(N, rspamd_config, 'no ANN key for %s:%s in Redis key %s',
-            rule.prefix, set.name, ann_key)
+          rule.prefix, set.name, ann_key)
       end
     end
   end
   lua_redis.redis_make_request_taskless(ev_base,
-      rspamd_config,
-      rule.redis,
-      nil,
-      false, -- is write
-      data_cb, --callback
-      'HMGET', -- command
-      { ann_key, 'ann', 'roc_thresholds', 'pca' }, -- arguments
-      { opaque_data = true }
+    rspamd_config,
+    rule.redis,
+    nil,
+    false,                                                                       -- is write
+    data_cb,                                                                     --callback
+    'HMGET',                                                                     -- command
+    { ann_key, 'ann', 'roc_thresholds', 'pca', 'providers_meta', 'norm_stats' }, -- arguments
+    { opaque_data = true }
   )
 end
 
@@ -595,34 +660,34 @@ local function process_existing_ann(_, ev_base, rule, set, profiles)
         if set.ann.version < sel_elt.version then
           -- Load new ann
           rspamd_logger.infox(rspamd_config, 'ann %s is changed, ' ..
-              'our version = %s, remote version = %s',
-              rule.prefix .. ':' .. set.name,
-              set.ann.version,
-              sel_elt.version)
+            'our version = %s, remote version = %s',
+            rule.prefix .. ':' .. set.name,
+            set.ann.version,
+            sel_elt.version)
           load_new_ann(rule, ev_base, set, sel_elt, min_diff)
         else
           lua_util.debugm(N, rspamd_config, 'ann %s is not changed, ' ..
-              'our version = %s, remote version = %s',
-              rule.prefix .. ':' .. set.name,
-              set.ann.version,
-              sel_elt.version)
+            'our version = %s, remote version = %s',
+            rule.prefix .. ':' .. set.name,
+            set.ann.version,
+            sel_elt.version)
         end
       else
         -- We have some different ANN, so we need to compare distance
         if set.ann.distance > min_diff then
           -- Load more specific ANN
           rspamd_logger.infox(rspamd_config, 'more specific ann is available for %s, ' ..
-              'our distance = %s, remote distance = %s',
-              rule.prefix .. ':' .. set.name,
-              set.ann.distance,
-              min_diff)
+            'our distance = %s, remote distance = %s',
+            rule.prefix .. ':' .. set.name,
+            set.ann.distance,
+            min_diff)
           load_new_ann(rule, ev_base, set, sel_elt, min_diff)
         else
           lua_util.debugm(N, rspamd_config, 'ann %s is not changed or less specific, ' ..
-              'our distance = %s, remote distance = %s',
-              rule.prefix .. ':' .. set.name,
-              set.ann.distance,
-              min_diff)
+            'our distance = %s, remote distance = %s',
+            rule.prefix .. ':' .. set.name,
+            set.ann.distance,
+            min_diff)
         end
       end
     else
@@ -660,14 +725,14 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
     local ann_key = sel_elt.redis_key
 
     lua_util.debugm(N, rspamd_config, "check if ANN %s needs to be trained",
-        ann_key)
+      ann_key)
 
     -- Create continuation closure
     local redis_len_cb_gen = function(cont_cb, what, is_final)
       return function(err, data)
         if err then
           rspamd_logger.errx(rspamd_config,
-              'cannot get ANN %s trains %s from redis: %s', what, ann_key, err)
+            'cannot get ANN %s trains %s from redis: %s', what, ann_key, err)
         elseif data and type(data) == 'number' or type(data) == 'string' then
           local ntrains = tonumber(data) or 0
           lens[what] = ntrains
@@ -688,67 +753,65 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
               end
               if max_len >= rule.train.max_trains and fun.all(len_bias_check_pred, lens) then
                 lua_util.debugm(N, rspamd_config,
-                    'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
-                    ann_key, lens, rule.train.max_trains, what)
+                  'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
+                  ann_key, lens, rule.train.max_trains, what)
                 cont_cb()
               else
                 lua_util.debugm(N, rspamd_config,
-                    'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
-                    ann_key, what, lens, rule.train.max_trains)
+                  'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
+                  ann_key, what, lens, rule.train.max_trains)
               end
             else
               -- Probabilistic mode, just ensure that at least one vector is okay
               if min_len > 0 and max_len >= rule.train.max_trains then
                 lua_util.debugm(N, rspamd_config,
-                    'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
-                    ann_key, lens, rule.train.max_trains, what)
+                  'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
+                  ann_key, lens, rule.train.max_trains, what)
                 cont_cb()
               else
                 lua_util.debugm(N, rspamd_config,
-                    'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
-                    ann_key, what, lens, rule.train.max_trains)
+                  'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
+                  ann_key, what, lens, rule.train.max_trains)
               end
             end
-
           else
             lua_util.debugm(N, rspamd_config,
-                'checked %s vectors in ANN %s: %s vectors; %s required, need to check other class vectors',
-                what, ann_key, ntrains, rule.train.max_trains)
+              'checked %s vectors in ANN %s: %s vectors; %s required, need to check other class vectors',
+              what, ann_key, ntrains, rule.train.max_trains)
             cont_cb()
           end
         end
       end
-
     end
 
     local function initiate_train()
       rspamd_logger.infox(rspamd_config,
-          'need to learn ANN %s after %s required learn vectors',
-          ann_key, lens)
+        'need to learn ANN %s after %s required learn vectors',
+        ann_key, lens)
       do_train_ann(worker, ev_base, rule, set, ann_key)
     end
 
     -- Spam vector is OK, check ham vector length
     local function check_ham_len()
       lua_redis.redis_make_request_taskless(ev_base,
-          rspamd_config,
-          rule.redis,
-          nil,
-          false, -- is write
-          redis_len_cb_gen(initiate_train, 'ham', true), --callback
-          'SCARD', -- command
-          { ann_key .. '_ham_set' }
+        rspamd_config,
+        rule.redis,
+        nil,
+        false,                                         -- is write
+        redis_len_cb_gen(initiate_train, 'ham', true), --callback
+        'SCARD',                                       -- command
+        { ann_key .. '_ham_set' }
       )
     end
 
     lua_redis.redis_make_request_taskless(ev_base,
-        rspamd_config,
-        rule.redis,
-        nil,
-        false, -- is write
-        redis_len_cb_gen(check_ham_len, 'spam', false), --callback
-        'SCARD', -- command
-        { ann_key .. '_spam_set' }
+      rspamd_config,
+      rule.redis,
+      nil,
+      false,                                          -- is write
+      redis_len_cb_gen(check_ham_len, 'spam', false), --callback
+      'SCARD',                                        -- command
+      { ann_key .. '_spam_set' }
     )
   end
 end
@@ -761,7 +824,7 @@ local function load_ann_profile(element)
   local res, ucl_err = parser:parse_string(element)
   if not res then
     rspamd_logger.warnx(rspamd_config, 'cannot parse ANN from redis: %s',
-        ucl_err)
+      ucl_err)
     return nil
   else
     local profile = parser:get_object()
@@ -781,11 +844,11 @@ local function check_anns(worker, cfg, ev_base, rule, process_callback, what)
     local function members_cb(err, data)
       if err then
         rspamd_logger.errx(cfg, 'cannot get ANNs list from redis: %s',
-            err)
+          err)
         set.can_store_vectors = true
       elseif type(data) == 'table' then
         lua_util.debugm(N, cfg, '%s: process element %s:%s',
-            what, rule.prefix, set.name)
+          what, rule.prefix, set.name)
         process_callback(worker, ev_base, rule, set, fun.map(load_ann_profile, data))
         set.can_store_vectors = true
       end
@@ -797,13 +860,13 @@ local function check_anns(worker, cfg, ev_base, rule, process_callback, what)
       -- Select the most appropriate to our profile but it should not differ by more
       -- than 30% of symbols
       lua_redis.redis_make_request_taskless(ev_base,
-          cfg,
-          rule.redis,
-          nil,
-          false, -- is write
-          members_cb, --callback
-          'ZREVRANGE', -- command
-          { set.prefix, '0', tostring(settings.max_profiles) } -- arguments
+        cfg,
+        rule.redis,
+        nil,
+        false,                                               -- is write
+        members_cb,                                          --callback
+        'ZREVRANGE',                                         -- command
+        { set.prefix, '0', tostring(settings.max_profiles) } -- arguments
       )
     end
   end -- Cycle over all settings
@@ -817,23 +880,23 @@ local function cleanup_anns(rule, cfg, ev_base)
     local function invalidate_cb(err, data)
       if err then
         rspamd_logger.errx(cfg, 'cannot exec invalidate script in redis: %s',
-            err)
+          err)
       elseif type(data) == 'table' then
         for _, expired in ipairs(data) do
           local profile = load_ann_profile(expired)
           rspamd_logger.infox(cfg, 'invalidated ANN for %s; redis key: %s; version=%s',
-              rule.prefix .. ':' .. set.name,
-              profile.redis_key,
-              profile.version)
+            rule.prefix .. ':' .. set.name,
+            profile.redis_key,
+            profile.version)
         end
       end
     end
 
     if type(set) == 'table' then
       lua_redis.exec_redis_script(neural_common.redis_script_id.maybe_invalidate,
-          { ev_base = ev_base, is_write = true },
-          invalidate_cb,
-          { set.prefix, tostring(settings.max_profiles) })
+        { ev_base = ev_base, is_write = true },
+        invalidate_cb,
+        { set.prefix, tostring(settings.max_profiles) })
     end
   end
 end
@@ -852,14 +915,14 @@ local function ann_push_vector(task)
 
   if verdict == 'passthrough' then
     lua_util.debugm(N, task, 'ignore task as its verdict is %s(%s)',
-        verdict, score)
+      verdict, score)
 
     return
   end
 
   if score ~= score then
     lua_util.debugm(N, task, 'ignore task as its score is nan (%s verdict)',
-        verdict)
+      verdict)
 
     return
   end
@@ -872,7 +935,6 @@ local function ann_push_vector(task)
     else
       lua_util.debugm(N, task, 'settings not found in rule %s', rule.prefix)
     end
-
   end
 end
 
@@ -930,10 +992,23 @@ for k, r in pairs(rules) do
 
   if rule_elt.max_inputs and not has_blas then
     rspamd_logger.errx('cannot set max inputs to %s as BLAS is not compiled in',
-        rule_elt.name, rule_elt.max_inputs)
+      rule_elt.name, rule_elt.max_inputs)
     rule_elt.max_inputs = nil
   end
 
+  -- Phase 4: basic provider config validation
+  if rule_elt.providers and #rule_elt.providers > 0 then
+    for i, pcfg in ipairs(rule_elt.providers) do
+      if not (pcfg.type or pcfg.name) then
+        rspamd_logger.errx(rspamd_config, 'provider at index %s in rule %s has no type/name; will be ignored', i, k)
+      end
+      if (pcfg.type == 'llm' or pcfg.name == 'llm') and not (pcfg.model or (rspamd_config:get_all_opt('gpt') or {}).model) then
+        rspamd_logger.errx(rspamd_config,
+          'llm provider in rule %s requires model; please set providers[i].model or gpt.model', k)
+      end
+    end
+  end
+
   rspamd_logger.infox(rspamd_config, "register ann rule %s", k)
   settings.rules[k] = rule_elt
   rspamd_config:set_metric_symbol({
@@ -980,21 +1055,21 @@ for _, rule in pairs(settings.rules) do
   rspamd_config:add_on_load(function(cfg, ev_base, worker)
     if worker:is_scanner() then
       rspamd_config:add_periodic(ev_base, 0.0,
-          function(_, _)
-            return check_anns(worker, cfg, ev_base, rule, process_existing_ann,
-                'try_load_ann')
-          end)
+        function(_, _)
+          return check_anns(worker, cfg, ev_base, rule, process_existing_ann,
+            'try_load_ann')
+        end)
     end
 
     if worker:is_primary_controller() then
       -- We also want to train neural nets when they have enough data
       rspamd_config:add_periodic(ev_base, 0.0,
-          function(_, _)
-            -- Clean old ANNs
-            cleanup_anns(rule, cfg, ev_base)
-            return check_anns(worker, cfg, ev_base, rule, maybe_train_existing_ann,
-                'try_train_ann')
-          end)
+        function(_, _)
+          -- Clean old ANNs
+          cleanup_anns(rule, cfg, ev_base)
+          return check_anns(worker, cfg, ev_base, rule, maybe_train_existing_ann,
+            'try_train_ann')
+        end)
     end
   end)
 end