return registered_providers[name]
end
+-- ANN architecture registry: maps an architecture name to its builder.
+-- A builder is a function(n_inputs, rule) -> kann object: given the input
+-- vector size n_inputs and the rule, it returns the network for that rule.
+-- The built-in 'symbol', 'embedding' and 'conv1d' architectures are
+-- registered below; third-party modules can register their own (e.g.
+-- attention pooling) via the exported register_architecture function and
+-- select them with `rule.architecture = "<name>"`.
+local registered_architectures = {}
+
+--- Registers an ANN architecture builder under `name`.
+-- A later registration with the same name replaces the earlier builder.
+-- @param name string non-empty architecture name (looked up via get_architecture)
+-- @param builder function(n, rule) -> kann object
+-- @raise error if name is not a non-empty string or builder is not a function
+local function register_architecture(name, builder)
+  if type(name) ~= 'string' or name == '' then
+    error(string.format('neural architecture name must be a non-empty string (got %s)',
+      type(name)), 2)
+  end
+  if type(builder) ~= 'function' then
+    -- Fail at registration time: otherwise a bad builder is only noticed when
+    -- an ANN is created for a rule that selects this architecture.
+    error(string.format('neural architecture %q: builder must be a function (got %s)',
+      name, type(builder)), 2)
+  end
+  registered_architectures[name] = builder
+end
+
+--- Looks up a registered ANN architecture builder.
+-- @param name architecture name used at registration
+-- @return the builder stored under `name`, or nil if none is registered
+local function get_architecture(name)
+  local builder = registered_architectures[name]
+  return builder
+end
+
-- Forward declaration
local result_to_vector
return create_embedding_ann(n, rule)
end
+-- Attention ANN: learned multi-head attention pooling over a sequence of
+-- word vectors, followed by a dense head on the pooled representation.
+-- The sequence provider (output_mode = "sequence") must come FIRST in the
+-- input vector. Anything after it (metatokens, other providers) is routed
+-- around the attention layer and concatenated with the pooled output before
+-- the dense head (late fusion). For such a hybrid layout, set
+-- attention.channels to the per-word dimension so that the sequence length
+-- can be derived; with no tail (fusion.include_meta = false and a single
+-- provider) channels is derived from the input size.
-- Detects if rule input contains dense embedding features: any provider other
-- than plain symbols/metatokens (llm, fasttext_embed, text_hash, ...).
-- Such inputs need the embedding architecture and a lower learning rate:
return false
end
--- Main ANN factory function - auto-selects architecture based on rule configuration
-local function create_ann(n, nlayers, rule)
- -- Check for conv1d architecture first
+-- Picks the architecture name for a rule with no explicit `architecture` field,
+-- using the historical auto-selection so such configs build the same network as
+-- before. Takes the rule table; returns 'conv1d', 'embedding' or 'symbol'.
+local function default_architecture(rule)
if rule.conv1d then
- lua_util.debugm(N, rspamd_config, 'creating conv1d ANN with %s inputs', n)
- return create_conv1d_ann(n, rule)
+ return 'conv1d'
+ end
+ if uses_dense_features(rule) or rule.layers ~= nil -- `~= nil`: an option set to false still counts
+ or rule.use_layernorm ~= nil or rule.dropout ~= nil then
+ return 'embedding'
end
+ return 'symbol'
+end
- -- Check if we should use the enhanced embedding architecture
- -- Conditions: any dense feature provider, or explicit multi-layer config
- local use_embedding_arch = uses_dense_features(rule)
- or rule.layers ~= nil
- or rule.use_layernorm ~= nil
- or rule.dropout ~= nil
+-- Register the architectures shipped with this module. Other modules add
+-- theirs via register_architecture and pick one with `rule.architecture = "<name>"`.
+for _, builtin in ipairs({
+  { 'symbol', create_symbol_ann },
+  { 'embedding', create_embedding_ann },
+  { 'conv1d', create_conv1d_ann },
+}) do
+  register_architecture(builtin[1], builtin[2])
+end
- if use_embedding_arch then
- lua_util.debugm(N, rspamd_config, 'creating multi-layer embedding ANN with %s inputs', n)
- return create_embedding_ann(n, rule)
- else
- lua_util.debugm(N, rspamd_config, 'creating simple symbol ANN with %s inputs', n)
- return create_symbol_ann(n, rule)
+-- Main ANN factory: builds the network for `rule` with `n` inputs by calling the
+-- builder registered for `rule.architecture`, or for the auto-selected default when
+-- that field is unset. `nlayers` is not used here. Raises if no builder is registered.
+local function create_ann(n, nlayers, rule)
+ local arch = rule.architecture or default_architecture(rule)
+ local builder = get_architecture(arch)
+ if not builder then
+ error(string.format('unknown neural architecture %q for rule %s ' ..
+ '(is the module providing it loaded?)', tostring(arch), rule.prefix or '?'))
end
+ lua_util.debugm(N, rspamd_config, 'creating %s ANN with %s inputs', arch, n)
+ return builder(n, rule) -- result of the builder is returned unchanged
end
-- Fills ANN data for a specific settings element
pending_train_key = pending_train_key,
providers_config_digest = providers_config_digest,
register_provider = register_provider,
+ register_architecture = register_architecture,
+ get_architecture = get_architecture,
plugin_ver = plugin_ver,
process_rules_settings = process_rules_settings,
redis_ann_prefix = redis_ann_prefix,