]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Feature] Multi-layer funnel architecture for LLM embeddings
authorVsevolod Stakhov <vsevolod@rspamd.com>
Tue, 20 Jan 2026 14:25:29 +0000 (14:25 +0000)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Tue, 20 Jan 2026 14:25:29 +0000 (14:25 +0000)
Add improved neural network architecture specifically for LLM embedding
inputs, while preserving backward compatibility for symbol-based rules.

Key changes:
- New create_embedding_ann() with multi-layer funnel architecture
- Auto-detection of LLM providers via uses_llm_embeddings()
- Support for configurable layers, dropout, layer normalization
- GELU activation by default when available (falls back to ReLU)
- Layer size auto-scaling based on input dimension:
  - >512 dims: 3 layers (0.5, 0.25, 0.125)
  - 256-512 dims: 2 layers (0.5, 0.25)
  - <256 dims: 1 layer (0.5)

Bug fixes:
- Wrap create_ann in pcall to handle errors gracefully
- Reset learning_spawned flag on ANN creation failure
- Replace assert(false) with proper error logging that resets state
- Prevents training from getting stuck after errors

New configuration options:
- layers: explicit layer size multipliers
- dropout: dropout rate (default 0.2 for embeddings)
- use_layernorm: enable layer normalization (default true)
- activation: 'gelu' or 'relu' (default 'gelu' if available)

lualib/plugins/neural.lua

index 683bf67adf21374c90ca6520657c8e4b2157e231..68bdb3c3dc698c02ee9eca5463f24f25a1369e50 100644 (file)
@@ -53,7 +53,12 @@ local default_options = {
   lock_expire = 600,
   learning_spawned = false,
   ann_expire = 60 * 60 * 24 * 2,    -- 2 days
-  hidden_layer_mult = 1.5,          -- number of neurons in the hidden layer
+  hidden_layer_mult = 1.5,          -- number of neurons in the hidden layer (symbol-based mode)
+  -- Multi-layer architecture settings (for LLM embeddings mode)
+  layers = nil,                     -- layer size multipliers (auto-computed based on input dim if nil)
+  dropout = nil,                    -- dropout rate (0.2 default for embeddings, nil=disabled for symbols)
+  use_layernorm = nil,              -- enable layer normalization (true default for embeddings)
+  activation = nil,                 -- activation function: 'relu' or 'gelu' (default: gelu for embeddings, relu for symbols)
   roc_enabled = false,              -- Use ROC to find the best possible thresholds for ham and spam. If spam_score_threshold or ham_score_threshold is defined, it takes precedence over ROC thresholds.
   roc_misclassification_cost = 0.5, -- Cost of misclassifying a spam message (must be 0..1).
   spam_score_threshold = nil,       -- neural score threshold for spam (must be 0..1 or nil to disable)
@@ -185,16 +190,112 @@ local function load_scripts()
   end
 end
 
-local function create_ann(n, nlayers, rule)
-  -- We ignore number of layers so far when using kann
+-- Creates a simple single-layer ANN for symbol-based inputs (backward compatible)
+local function create_symbol_ann(n, rule)
   local nhidden = math.floor(n * (rule.hidden_layer_mult or 1.0) + 1.0)
   local t = rspamd_kann.layer.input(n)
   t = rspamd_kann.transform.relu(t)
-  t = rspamd_kann.layer.dense(t, nhidden);
+  t = rspamd_kann.layer.dense(t, nhidden)
+  t = rspamd_kann.layer.cost(t, 1, rspamd_kann.cost.ceb_neg)
+  return rspamd_kann.new.kann(t)
+end
+
+-- Creates a multi-layer funnel ANN optimized for high-dimensional embeddings
+-- Architecture: Input → [Dense → LayerNorm → Activation → Dropout]* → Cost
+local function create_embedding_ann(n, rule)
+  local t = rspamd_kann.layer.input(n)
+
+  -- Get architecture settings with smart defaults based on input dimension
+  local layers = rule.layers
+  if not layers then
+    -- Auto-compute layer sizes based on input dimension
+    if n > 512 then
+      layers = { 0.5, 0.25, 0.125 } -- 3 layers for large embeddings (e.g., 1024-dim)
+    elseif n > 256 then
+      layers = { 0.5, 0.25 }        -- 2 layers for medium embeddings
+    else
+      layers = { 0.5 }              -- 1 layer for small embeddings
+    end
+  end
+
+  local dropout_rate = rule.dropout
+  if dropout_rate == nil then
+    dropout_rate = 0.2 -- Default dropout for regularization
+  end
+
+  local use_layernorm = rule.use_layernorm
+  if use_layernorm == nil then
+    use_layernorm = true -- Default: enable layer normalization
+  end
+
+  -- Select activation function: GELU for embeddings (better for high-dim), ReLU as fallback
+  local activation = rule.activation
+  if not activation then
+    -- Default to GELU for embeddings if available
+    activation = rspamd_kann.transform.gelu and 'gelu' or 'relu'
+  end
+  local activate_fn = (activation == 'gelu' and rspamd_kann.transform.gelu) or rspamd_kann.transform.relu
+
+  lua_util.debugm(N, rspamd_config, 'embedding ANN: %s layers, dropout=%s, layernorm=%s, activation=%s',
+    #layers, dropout_rate, use_layernorm, activation)
+
+  -- Build funnel architecture with graduated dimension reduction
+  for i, layer_mult in ipairs(layers) do
+    local layer_size = math.max(math.floor(n * layer_mult), 32)
+
+    -- Dense layer
+    t = rspamd_kann.layer.dense(t, layer_size)
+
+    -- Layer normalization for training stability
+    if use_layernorm then
+      t = rspamd_kann.layer.layernorm(t)
+    end
+
+    -- Activation function (GELU or ReLU)
+    t = activate_fn(t)
+
+    -- Dropout for regularization (less on final hidden layer)
+    if dropout_rate > 0 then
+      local rate = (i == #layers) and (dropout_rate * 0.5) or dropout_rate
+      t = rspamd_kann.layer.dropout(t, rate)
+    end
+  end
+
   t = rspamd_kann.layer.cost(t, 1, rspamd_kann.cost.ceb_neg)
   return rspamd_kann.new.kann(t)
 end
 
+-- Detects if rule uses LLM embeddings provider
+local function uses_llm_embeddings(rule)
+  if not rule.providers then
+    return false
+  end
+  for _, p in ipairs(rule.providers) do
+    if p.type == 'llm' then
+      return true
+    end
+  end
+  return false
+end
+
+-- Main ANN factory function - auto-selects architecture based on rule configuration
+local function create_ann(n, nlayers, rule)
+  -- Check if we should use the enhanced embedding architecture
+  -- Conditions: has LLM provider, or explicit multi-layer config, or large input dimension
+  local use_embedding_arch = uses_llm_embeddings(rule)
+    or rule.layers ~= nil
+    or rule.use_layernorm ~= nil
+    or rule.dropout ~= nil
+
+  if use_embedding_arch then
+    lua_util.debugm(N, rspamd_config, 'creating multi-layer embedding ANN with %s inputs', n)
+    return create_embedding_ann(n, rule)
+  else
+    lua_util.debugm(N, rspamd_config, 'creating simple symbol ANN with %s inputs', n)
+    return create_symbol_ann(n, rule)
+  end
+end
+
 -- Fills ANN data for a specific settings element
 local function fill_set_ann(set, ann_key)
   if not set.ann then
@@ -852,13 +953,22 @@ local function spawn_train(params)
       #params.set.symbols, meta_functions.rspamd_count_metatokens(), n)
   end
 
-  -- Now we can train ann
-  local train_ann = create_ann(params.rule.max_inputs or n, 3, params.rule)
+  -- Now we can train ann - wrap in pcall to catch KANN errors
+  local create_ok, train_ann = pcall(create_ann, params.rule.max_inputs or n, 3, params.rule)
+  if not create_ok then
+    rspamd_logger.errx(rspamd_config, 'failed to create ANN for %s:%s: %s',
+      params.rule.prefix, params.set.name, train_ann)
+    params.set.learning_spawned = false
+    return
+  end
 
   if #params.ham_vec + #params.spam_vec < params.rule.train.max_trains / 2 then
-    -- Invalidate ANN as it is definitely invalid
-    -- TODO: add invalidation
-    assert(false)
+    -- Insufficient training data, reset flag and return
+    rspamd_logger.errx(rspamd_config, 'insufficient training data for ANN %s:%s: spam=%s ham=%s (need at least %s total)',
+      params.rule.prefix, params.set.name,
+      #params.spam_vec, #params.ham_vec, params.rule.train.max_trains / 2)
+    params.set.learning_spawned = false
+    return
   else
     local inputs, outputs = {}, {}