]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Feature] neural: pluggable feature-provider and ANN-architecture registries
authorVsevolod Stakhov <vsevolod@rspamd.com>
Sat, 13 Jun 2026 12:18:01 +0000 (13:18 +0100)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Sat, 13 Jun 2026 12:18:01 +0000 (13:18 +0100)
Turn the neural plugin into an extension point so third-party (including
closed-source) modules can add feature providers and network topologies
without patching the core.

* register_architecture(name, builder) / get_architecture(name): a
  registry of ANN builders, function(n_inputs, rule) -> kann network.
  The built-in 'symbol', 'embedding' and 'conv1d' architectures are now
  registered through it; create_ann() dispatches on rule.architecture
  and falls back to the historical auto-selection when it is unset, so
  existing configs are unaffected.
* register_provider (already present) and register_architecture are
  exported from the neural module, so a module that does
  require 'plugins/neural' can register a custom provider or
  architecture and select it with provider type / rule.architecture.

An unknown rule.architecture now fails loudly with a hint that the
providing module may not be loaded, instead of silently falling back.

lualib/plugins/neural.lua

index 14ebbcb223f028e9072c685f8956e42abee86372..d15dc31fb63c65d0af187f897794703b76da21c9 100644 (file)
@@ -137,6 +137,26 @@ local function get_provider(name)
   return registered_providers[name]
 end
 
+-- ANN architecture registry. An architecture is a builder
+--   function(n_inputs, rule) -> kann object
+-- that turns an input vector of size n_inputs into a compiled network. The
+-- built-in 'symbol', 'embedding' and 'conv1d' architectures are registered
+-- below; third-party modules can register their own (e.g. attention pooling)
+-- via the public register_architecture API and select them with
+-- `rule.architecture = "<name>"`.
+local registered_architectures = {}
+
+--- Registers an ANN architecture builder
+-- @param name string
+-- @param builder function(n, rule) -> kann object
+local function register_architecture(name, builder)
+  registered_architectures[name] = builder
+end
+
+local function get_architecture(name)
+  return registered_architectures[name]
+end
+
 -- Forward declaration
 local result_to_vector
 
@@ -283,6 +303,15 @@ local function create_conv1d_ann(n, rule)
   return create_embedding_ann(n, rule)
 end
 
+-- Attention ANN: learned multi-head attention pooling over a sequence of
+-- word vectors, followed by a dense head on the pooled representation.
+-- The sequence provider (output_mode = "sequence") must come FIRST in the
+-- input vector. Anything after it (metatokens, other providers) is routed
+-- around the attention layer and concatenated with the pooled output before
+-- the dense head (late fusion). For such a hybrid layout, set
+-- attention.channels to the per-word dimension so that the sequence length
+-- can be derived; with no tail (fusion.include_meta = false and a single
+-- provider) channels is derived from the input size.
 -- Detects if rule input contains dense embedding features: any provider other
 -- than plain symbols/metatokens (llm, fasttext_embed, text_hash, ...).
 -- Such inputs need the embedding architecture and a lower learning rate:
@@ -302,28 +331,38 @@ local function uses_dense_features(rule)
   return false
 end
 
--- Main ANN factory function - auto-selects architecture based on rule configuration
-local function create_ann(n, nlayers, rule)
-  -- Check for conv1d architecture first
+-- Resolves the architecture name for a rule when not set explicitly. Keeps the
+-- historical auto-selection so existing configs (no `architecture` field) build
+-- the same network as before.
+local function default_architecture(rule)
   if rule.conv1d then
-    lua_util.debugm(N, rspamd_config, 'creating conv1d ANN with %s inputs', n)
-    return create_conv1d_ann(n, rule)
+    return 'conv1d'
+  end
+  if uses_dense_features(rule) or rule.layers ~= nil
+      or rule.use_layernorm ~= nil or rule.dropout ~= nil then
+    return 'embedding'
   end
+  return 'symbol'
+end
 
-  -- Check if we should use the enhanced embedding architecture
-  -- Conditions: any dense feature provider, or explicit multi-layer config
-  local use_embedding_arch = uses_dense_features(rule)
-    or rule.layers ~= nil
-    or rule.use_layernorm ~= nil
-    or rule.dropout ~= nil
+-- Built-in architectures. Third-party modules register their own via
+-- register_architecture and select them with `rule.architecture = "<name>"`.
+register_architecture('symbol', create_symbol_ann)
+register_architecture('embedding', create_embedding_ann)
+register_architecture('conv1d', create_conv1d_ann)
 
-  if use_embedding_arch then
-    lua_util.debugm(N, rspamd_config, 'creating multi-layer embedding ANN with %s inputs', n)
-    return create_embedding_ann(n, rule)
-  else
-    lua_util.debugm(N, rspamd_config, 'creating simple symbol ANN with %s inputs', n)
-    return create_symbol_ann(n, rule)
+-- Main ANN factory: dispatches to a registered architecture builder. An
+-- explicit `rule.architecture` wins; otherwise the architecture is auto-
+-- selected from the rule shape for backward compatibility.
+local function create_ann(n, nlayers, rule)
+  local arch = rule.architecture or default_architecture(rule)
+  local builder = get_architecture(arch)
+  if not builder then
+    error(string.format('unknown neural architecture %q for rule %s ' ..
+      '(is the module providing it loaded?)', tostring(arch), rule.prefix or '?'))
   end
+  lua_util.debugm(N, rspamd_config, 'creating %s ANN with %s inputs', arch, n)
+  return builder(n, rule)
 end
 
 -- Fills ANN data for a specific settings element
@@ -1683,6 +1722,8 @@ return {
   pending_train_key = pending_train_key,
   providers_config_digest = providers_config_digest,
   register_provider = register_provider,
+  register_architecture = register_architecture,
+  get_architecture = get_architecture,
   plugin_ver = plugin_ver,
   process_rules_settings = process_rules_settings,
   redis_ann_prefix = redis_ann_prefix,