return ann;
}
+/*******************************
+ *** @@MERGE: weight merging ***
+ ********************************/
+
+/**
+ * Merge weights from source ANN into destination ANN using linear interpolation:
+ *   w_dst = (1 - alpha) * w_dst + alpha * w_src
+ *
+ * Only the trainable variables (the flat `x` array) are blended; constants are
+ * not touched. The sole architecture check done here is that both networks
+ * hold the same total number of variable values, so callers that need a
+ * stricter guarantee should call kann_is_compatible() first.
+ *
+ * alpha is not range-checked by this function; validating it is left to the
+ * caller.
+ *
+ * @param dst destination ANN (modified in place)
+ * @param src source ANN (not modified)
+ * @param alpha weight for source ANN (0.0 - 1.0)
+ * @return 0 on success, -1 on error (NULL argument, different number of
+ *         variable values, or a missing weight array)
+ */
+int kann_merge_weights(kann_t *dst, const kann_t *src, float alpha)
+{
+	int i, n_var_dst, n_var_src;
+	const float keep = 1.0f - alpha;
+
+	if (!dst || !src) {
+		return -1;
+	}
+
+	n_var_dst = kad_size_var(dst->n, dst->v);
+	n_var_src = kad_size_var(src->n, src->v);
+
+	/* Element-wise blending only makes sense for equally sized weight arrays */
+	if (n_var_dst != n_var_src) {
+		return -1;
+	}
+
+	/* Never dereference a weight array that was not allocated */
+	if (n_var_dst > 0 && (!dst->x || !src->x)) {
+		return -1;
+	}
+
+	/* Merge variable weights (trainable parameters) */
+	for (i = 0; i < n_var_dst; i++) {
+		dst->x[i] = keep * dst->x[i] + alpha * src->x[i];
+	}
+
+	return 0;
+}
+
+/**
+ * Check if two ANNs have compatible architecture for weight merging.
+ *
+ * Two networks are reported as compatible when:
+ *  - they hold the same total number of variable values,
+ *  - their variable nodes, taken in graph order, have pairwise equal lengths
+ *    (so that element i of one weight array lines up with element i of the
+ *    other), and
+ *  - their input and output dimensions match.
+ *
+ * @param a first ANN
+ * @param b second ANN
+ * @return 1 if compatible, 0 otherwise (including when either pointer is NULL)
+ */
+int kann_is_compatible(const kann_t *a, const kann_t *b)
+{
+	int ia = 0, ib = 0;
+
+	if (!a || !b) {
+		return 0;
+	}
+
+	if (kad_size_var(a->n, a->v) != kad_size_var(b->n, b->v)) {
+		return 0;
+	}
+
+	/*
+	 * Equal totals are not enough: e.g. layers of 4+6 and 6+4 values have the
+	 * same sum but a different layout. Walk both graphs in lockstep over their
+	 * variable nodes and compare the per-node lengths.
+	 */
+	for (;;) {
+		while (ia < a->n && !kad_is_var(a->v[ia])) {
+			ia++;
+		}
+		while (ib < b->n && !kad_is_var(b->v[ib])) {
+			ib++;
+		}
+
+		if (ia >= a->n || ib >= b->n) {
+			break;
+		}
+
+		if (kad_len(a->v[ia]) != kad_len(b->v[ib])) {
+			return 0;
+		}
+
+		ia++;
+		ib++;
+	}
+
+	/* One network still has variable nodes left that the other lacks */
+	if (ia < a->n || ib < b->n) {
+		return 0;
+	}
+
+	/* Check input/output dimensions */
+	if (kann_dim_in(a) != kann_dim_in(b)) {
+		return 0;
+	}
+
+	if (kann_dim_out(a) != kann_dim_out(b)) {
+		return 0;
+	}
+
+	return 1;
+}
+
/**********************************************
*** @@LAYER: layers and model generation ***
**********************************************/
kann_t *kann_load_fp(FILE *fp);
kann_t *kann_load(const char *fn);
+/* weight merging */
+int kann_merge_weights(kann_t *dst, const kann_t *src, float alpha);
+int kann_is_compatible(const kann_t *a, const kann_t *b);
+
#ifdef __cplusplus
}
#endif
--- /dev/null
+--[[
+Copyright (c) 2026, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[
+External neural model loading and merging.
+
+This module provides functionality to load pretrained neural models
+from external sources (HTTP/HTTPS) via the Maps infrastructure
+and merge them with locally trained weights.
+
+Model format (msgpack):
+{
+ magic = "RNM1", -- Rspamd Neural Model v1
+ version = 1, -- format version
+ model_version = 123, -- model training version (incremented on retrain)
+ providers_digest = "...", -- digest of providers config (must match local)
+ ann_data = "...", -- serialized KANN (zstd compressed)
+ pca_data = "...", -- optional PCA (zstd compressed)
+ norm_stats = {...}, -- normalization stats
+ roc_thresholds = {...}, -- ROC thresholds
+ created_at = timestamp,
+}
+
+Usage in neural config:
+ external_model = {
+ url = "https://your-provider.com/models/<digest>";
+ sign_key = "your_key"; -- optional signature verification
+ merge_alpha = 0.6; -- 60% external, 40% local
+ };
+]]--
+
+local lua_redis = require "lua_redis"
+local rspamd_kann = require "rspamd_kann"
+local rspamd_logger = require "rspamd_logger"
+local rspamd_util = require "rspamd_util"
+local rspamd_text = require "rspamd_text"
+local ucl = require "ucl"
+
+-- Model format constants
+local MODEL_MAGIC = "RNM1"
+local MODEL_FORMAT_VERSION = 1
+
+local exports = {}
+
+-- Cache of loaded external models: url -> { model, map, callbacks }
+local external_model_cache = {}
+
+--- Parse external model from msgpack data
+-- The payload is first tried as zstd-compressed data; when decompression
+-- fails it is parsed as plain msgpack. The decoded object must be a table
+-- carrying the expected `magic` and format `version`.
+-- @param data raw msgpack data (string or rspamd_text, possibly zstd compressed)
+-- @return table with model data or nil, error message
+function exports.parse_model(data)
+  if not data then
+    return nil, "no data"
+  end
+
+  -- Convert rspamd_text to string if needed
+  local data_str
+  if type(data) == 'userdata' or (type(data) == 'table' and data.cookie) then
+    data_str = tostring(data)
+  else
+    data_str = data
+  end
+
+  if type(data_str) ~= 'string' then
+    return nil, string.format("unsupported data type: %s", type(data_str))
+  end
+
+  -- Try zstd decompression first, fall back to the raw payload
+  local decompressed = data_str
+  local err, decompressed_data = rspamd_util.zstd_decompress(data_str)
+  if not err and decompressed_data then
+    decompressed = tostring(decompressed_data)
+  end
+
+  -- Parse msgpack
+  local parser = ucl.parser()
+  local ok, parse_err = parser:parse_text(decompressed, 'msgpack')
+  if not ok then
+    return nil, "failed to parse msgpack: " .. tostring(parse_err or "unknown error")
+  end
+
+  local model = parser:get_object()
+
+  -- A syntactically valid msgpack document may decode to a scalar or nil;
+  -- indexing it below would raise instead of returning an error
+  if type(model) ~= 'table' then
+    return nil, string.format("invalid model: expected a table, got %s", type(model))
+  end
+
+  -- Validate model format; tostring() keeps the messages safe for any field type
+  if model.magic ~= MODEL_MAGIC then
+    return nil, string.format("invalid magic: expected %s, got %s",
+        MODEL_MAGIC, tostring(model.magic))
+  end
+
+  if model.version ~= MODEL_FORMAT_VERSION then
+    return nil, string.format("unsupported model version: %s (expected %s)",
+        tostring(model.version), tostring(MODEL_FORMAT_VERSION))
+  end
+
+  return model
+end
+
+--- Restore the KANN network embedded in a parsed model
+-- The `ann_data` field is stringified when it is not a plain string,
+-- zstd-decompressed and then handed to rspamd_kann.load.
+-- @param model parsed model table
+-- @return kann_t object or nil, error
+function exports.load_ann(model)
+  local packed = model.ann_data
+  if not packed then
+    return nil, "no ann_data in model"
+  end
+
+  -- Text-like payloads are turned into a Lua string before decompression
+  local kind = type(packed)
+  if kind == 'userdata' or (kind == 'table' and packed.cookie) then
+    packed = tostring(packed)
+  end
+
+  local zerr, raw = rspamd_util.zstd_decompress(packed)
+  if zerr then
+    return nil, "failed to decompress ann_data: " .. zerr
+  end
+
+  local net = rspamd_kann.load(raw)
+  if net then
+    return net
+  end
+
+  return nil, "failed to load KANN from model"
+end
+
+--- Restore the optional PCA tensor embedded in a parsed model
+-- A decompression failure is logged as a warning and reported as nil.
+-- @param model parsed model table
+-- @return tensor or nil
+function exports.load_pca(model)
+  local packed = model.pca_data
+  if not packed then
+    return nil
+  end
+
+  -- Text-like payloads are turned into a Lua string before decompression
+  local kind = type(packed)
+  if kind == 'userdata' or (kind == 'table' and packed.cookie) then
+    packed = tostring(packed)
+  end
+
+  local zerr, raw = rspamd_util.zstd_decompress(packed)
+  if not zerr then
+    -- The tensor library is only needed once there is something to load
+    local rspamd_tensor = require "rspamd_tensor"
+    return rspamd_tensor.load(raw)
+  end
+
+  rspamd_logger.warnx(rspamd_config, "failed to decompress pca_data: %s", zerr)
+  return nil
+end
+
+--- Check if model is compatible with local config
+-- Both digests are compared in their string form, so a digest that arrives
+-- as a non-string value (for instance after a round trip through storage)
+-- still compares equal to the same digest held as a Lua string.
+-- @param model parsed model table
+-- @param providers_digest local providers digest
+-- @return boolean, reason
+function exports.is_compatible(model, providers_digest)
+  if type(model) ~= 'table' or not model.providers_digest then
+    return false, "model has no providers_digest"
+  end
+
+  -- Without a local digest there is nothing to compare against; report it
+  -- instead of failing on a nil value while building the message below
+  if providers_digest == nil then
+    return false, "no local providers digest to compare with"
+  end
+
+  local model_digest = tostring(model.providers_digest)
+  local local_digest = tostring(providers_digest)
+
+  if model_digest ~= local_digest then
+    -- Only a short prefix of each digest is shown to keep the message readable
+    return false, string.format("providers digest mismatch: model=%s, local=%s",
+        model_digest:sub(1, 8), local_digest:sub(1, 8))
+  end
+
+  return true, "compatible"
+end
+
+--- Merge weights from external ANN into local ANN using interpolation
+-- w_new = alpha * w_external + (1-alpha) * w_local
+--
+-- The merge is done IN PLACE on `external_ann`: that object is modified and
+-- returned, `local_ann` is left untouched. Callers that need to keep the
+-- pristine external network must pass a copy.
+-- @param external_ann kann_t from external model (modified in place)
+-- @param local_ann kann_t from local training
+-- @param alpha weight for external (0.0 - 1.0), defaults to 0.5
+-- @return merged kann_t (the same object as external_ann) or nil, error
+function exports.merge_weights(external_ann, local_ann, alpha)
+  if not external_ann or not local_ann then
+    return nil, "missing ann"
+  end
+
+  alpha = tonumber(alpha or 0.5)
+
+  -- Report a bad alpha as an error value rather than letting the underlying
+  -- merge call raise; `alpha ~= alpha` is true only for NaN
+  if not alpha or alpha ~= alpha or alpha < 0 or alpha > 1 then
+    return nil, "invalid alpha: expected a number between 0.0 and 1.0"
+  end
+
+  -- Check compatibility first
+  if not external_ann:is_compatible(local_ann) then
+    return nil, "incompatible ANN architectures"
+  end
+
+  -- Use the external ANN as base and merge local weights into it: the local
+  -- weights get weight (1 - alpha), which leaves alpha for the external ones
+  local merged, err = external_ann:merge_weights(local_ann, 1.0 - alpha)
+
+  if not merged then
+    return nil, "merge failed: " .. tostring(err or "unknown")
+  end
+
+  return external_ann
+end
+
+--- Compose the model URL as `<base_url>/<providers_digest>`
+-- Any run of trailing slashes on the base URL is dropped first, so exactly
+-- one slash separates the two parts.
+-- @param base_url base URL
+-- @param providers_digest digest of providers config
+-- @return full URL
+function exports.build_model_url(base_url, providers_digest)
+  -- gsub also returns a substitution count; the single-target assignment
+  -- keeps only the resulting string
+  local root = base_url:gsub('/+$', '')
+  return ("%s/%s"):format(root, providers_digest)
+end
+
+--- Register external model as a map
+-- This uses the Maps infrastructure for HTTP loading with signature verification
+--
+-- One map is created per URL. Registering the same URL again does not create
+-- a second map: the new callback is subscribed to the existing one and, when
+-- a model has already been loaded, is called right away with the cached model.
+-- NOTE(review): models are checked against the providers digest of the FIRST
+-- registrant of a URL; later registrants are assumed to share it - confirm.
+-- @param cfg rspamd_config
+-- @param rule neural rule configuration
+-- @param providers_digest digest of providers config
+-- @param on_load_callback function(model_data, err) called when model is loaded/reloaded
+-- @return boolean success
+function exports.register_model_map(cfg, rule, providers_digest, on_load_callback)
+  local ext_cfg = rule.external_model
+  if not ext_cfg or not ext_cfg.url then
+    return false
+  end
+
+  local url = ext_cfg.url
+  local cache_key = url
+
+  -- Already registered: subscribe instead of silently dropping the callback
+  local entry = external_model_cache[cache_key]
+  if entry then
+    if on_load_callback then
+      entry.callbacks = entry.callbacks or {}
+      entry.callbacks[#entry.callbacks + 1] = on_load_callback
+      if entry.model then
+        on_load_callback(entry.model, nil)
+      end
+    end
+    return true
+  end
+
+  -- A single entry table is shared by the map callback and the cache, so a
+  -- reload updates the model without losing the map handle or the callbacks
+  entry = { callbacks = {} }
+  if on_load_callback then
+    entry.callbacks[1] = on_load_callback
+  end
+
+  -- Deliver a result (model or error) to every subscriber of this URL
+  local function notify(model, err)
+    for _, cb in ipairs(entry.callbacks) do
+      cb(model, err)
+    end
+  end
+
+  -- Map callback: called when map data is loaded
+  local function map_callback(data)
+    if not data then
+      rspamd_logger.errx(cfg, 'external neural model map returned no data for %s', url)
+      notify(nil, "no data from map")
+      return
+    end
+
+    -- Parse model
+    local model, parse_err = exports.parse_model(data)
+    if not model then
+      rspamd_logger.errx(cfg, 'failed to parse external neural model from %s: %s', url, parse_err)
+      notify(nil, parse_err)
+      return
+    end
+
+    -- Check compatibility
+    local compatible, reason = exports.is_compatible(model, providers_digest)
+    if not compatible then
+      rspamd_logger.errx(cfg, 'external neural model incompatible: %s', reason)
+      notify(nil, reason)
+      return
+    end
+
+    rspamd_logger.infox(cfg, 'loaded external neural model from %s (version=%s)',
+        url, model.model_version or 0)
+
+    -- Update cache in place
+    entry.model = model
+    entry.last_version = model.model_version
+    entry.last_load = os.time()
+
+    notify(model, nil)
+  end
+
+  -- Create callback map
+  local map = cfg:add_map({
+    url = url,
+    type = 'callback',
+    description = string.format('External neural model for rule %s', rule.prefix or 'default'),
+    callback = map_callback,
+    opaque_data = true, -- Get data as rspamd_text
+  })
+
+  if not map then
+    rspamd_logger.errx(cfg, 'failed to register external neural model map for %s', url)
+    return false
+  end
+
+  -- Set sign key if configured
+  if ext_cfg.sign_key then
+    map:set_sign_key(ext_cfg.sign_key)
+  end
+
+  entry.map = map
+  external_model_cache[cache_key] = entry
+
+  return true
+end
+
+--- Look up the model most recently cached for a URL
+-- @param url model URL
+-- @return cached model data, or nil when the URL is unknown or no model
+-- has been stored for it yet
+function exports.get_cached_model(url)
+  local entry = external_model_cache[url]
+  return entry and entry.model or nil
+end
+
+--- Build the normalised external-model settings for a neural rule
+-- The URL is taken from `external_model.url`, or derived from
+-- `external_model.base_url` plus the providers digest. Optional settings
+-- are filled in with their defaults.
+-- @param rule neural rule configuration
+-- @param providers_digest digest of providers config
+-- @return table with external model config or nil
+function exports.create_external_config(rule, providers_digest)
+  local ext = rule.external_model
+  if not ext then
+    return nil
+  end
+
+  local url = ext.url
+  if not url then
+    if not ext.base_url then
+      rspamd_logger.errx(rspamd_config, 'external_model requires url or base_url')
+      return nil
+    end
+    -- No explicit URL: derive it from the base URL and the digest
+    url = exports.build_model_url(ext.base_url, providers_digest)
+  end
+
+  local conf = {
+    url = url,
+    providers_digest = providers_digest,
+    loaded = false,
+  }
+
+  conf.sign_key = ext.sign_key
+  conf.merge_strategy = ext.merge_strategy or "interpolate"
+  conf.merge_alpha = ext.merge_alpha or 0.5
+  conf.check_interval = ext.check_interval or 86400 -- 24h
+  -- Enabled unless explicitly set to false
+  conf.local_fine_tune = ext.local_fine_tune ~= false
+  conf.min_local_samples = ext.min_local_samples or 50
+
+  return conf
+end
+
+--- Store external model metadata in Redis for later merge
+-- Writes the hash `<ann_key>_base` with the fields `version`, `ann_data`,
+-- `providers_digest` and `created_at`. Every failure, including a model
+-- without `ann_data` and a request that could not be issued, is logged and
+-- passed to `callback`.
+-- @param redis redis params
+-- @param ev_base event base
+-- @param ann_key Redis key for the ANN
+-- @param model_data parsed model data
+-- @param callback function(err)
+function exports.store_base_model(redis, ev_base, ann_key, model_data, callback)
+  -- Store base model version and compressed ann_data for re-merge
+  local base_key = ann_key .. "_base"
+
+  local function store_cb(err)
+    if err then
+      rspamd_logger.errx(rspamd_config, "failed to store base model: %s", err)
+    end
+    if callback then
+      callback(err)
+    end
+  end
+
+  -- A missing payload would leave a nil hole in the argument list below
+  local ann_data = model_data and model_data.ann_data
+  if not ann_data then
+    store_cb("no ann_data in model")
+    return
+  end
+
+  -- Ensure ann_data is rspamd_text for opaque storage
+  if type(ann_data) == 'string' then
+    -- Already compressed, convert to text
+    ann_data = rspamd_text.fromstring(ann_data)
+  end
+
+  -- Store base version and ann_data
+  local ret = lua_redis.redis_make_request_taskless(ev_base,
+      rspamd_config,
+      redis,
+      nil,
+      true, -- is write
+      store_cb,
+      'HMSET',
+      {
+        base_key,
+        'version', tostring(model_data.model_version or 0),
+        'ann_data', ann_data,
+        'providers_digest', model_data.providers_digest or '',
+        'created_at', tostring(model_data.created_at or os.time()),
+      },
+      { opaque_data = true }
+  )
+
+  -- NOTE(review): assumes a falsy result means the request was never issued
+  -- and store_cb will therefore not be called by the Redis layer - confirm
+  if not ret then
+    store_cb("cannot make redis request")
+  end
+end
+
+--- Load base model from Redis for re-merge
+-- Reads the hash `<ann_key>_base`. The request asks for opaque data, so the
+-- reply values are not plain Lua strings: scalar fields are converted to
+-- strings/numbers here, while `ann_data` is passed through unchanged.
+-- @param redis redis params
+-- @param ev_base event base
+-- @param ann_key Redis key for the ANN
+-- @param callback function(err, model_data)
+function exports.load_base_model(redis, ev_base, ann_key, callback)
+  local base_key = ann_key .. "_base"
+  local text_cookie = rspamd_text.cookie
+
+  -- NOTE(review): relies on the cookie field to tell an rspamd_text apart
+  -- from the placeholder returned for a missing hash field - confirm
+  local function is_text(v)
+    return type(v) == 'userdata' and v.cookie == text_cookie
+  end
+
+  -- Plain string form of a reply element, or nil when it holds no data
+  local function as_string(v)
+    if type(v) == 'string' then
+      return v
+    end
+    if is_text(v) then
+      return tostring(v)
+    end
+    return nil
+  end
+
+  local function load_cb(err, data)
+    if err then
+      callback(err, nil)
+      return
+    end
+
+    if type(data) ~= 'table' then
+      callback("no base model found", nil)
+      return
+    end
+
+    -- HMGET on an absent key still yields a reply table, so the payload
+    -- itself has to be checked
+    local ann_data = data[2]
+    if not (type(ann_data) == 'string' or is_text(ann_data)) then
+      callback("no base model found", nil)
+      return
+    end
+
+    local model_data = {
+      -- tonumber() does not understand userdata, hence the string conversion
+      model_version = tonumber(as_string(data[1])) or 0,
+      ann_data = ann_data, -- rspamd_text
+      providers_digest = as_string(data[3]),
+      created_at = tonumber(as_string(data[4])) or 0,
+    }
+
+    callback(nil, model_data)
+  end
+
+  local ret = lua_redis.redis_make_request_taskless(ev_base,
+      rspamd_config,
+      redis,
+      nil,
+      false, -- is write
+      load_cb,
+      'HMGET',
+      { base_key, 'version', 'ann_data', 'providers_digest', 'created_at' },
+      { opaque_data = true }
+  )
+
+  -- NOTE(review): assumes a falsy result means the request was never issued
+  -- and load_cb will therefore not be called by the Redis layer - confirm
+  if not ret then
+    callback("cannot make redis request", nil)
+  end
+end
+
+--- Serialize model to msgpack format
+-- Builds the model envelope (magic, format version, model version, providers
+-- digest, creation time) around the zstd-compressed ANN and, when given, the
+-- zstd-compressed PCA. Only these embedded payloads are compressed here; the
+-- msgpack envelope itself is returned as produced by ucl.to_format.
+-- @param ann kann_t object
+-- @param pca optional PCA tensor
+-- @param providers_digest providers config digest
+-- @param opts optional { model_version, norm_stats, roc_thresholds }
+-- @return msgpack-encoded model as returned by ucl.to_format (not compressed)
+function exports.serialize_model(ann, pca, providers_digest, opts)
+ opts = opts or {}
+
+ -- Save ANN to memory
+ local ann_text = ann:save()
+ local ann_compressed = rspamd_util.zstd_compress(ann_text)
+
+ local model = {
+ magic = MODEL_MAGIC,
+ version = MODEL_FORMAT_VERSION,
+ -- Model version defaults to 1 when the caller does not supply one
+ model_version = opts.model_version or 1,
+ providers_digest = providers_digest,
+ ann_data = ann_compressed,
+ created_at = os.time(),
+ }
+
+ -- PCA is optional and stored compressed, like the ANN
+ if pca then
+ local pca_text = pca:save()
+ model.pca_data = rspamd_util.zstd_compress(pca_text)
+ end
+
+ -- Optional metadata is copied through unchanged
+ if opts.norm_stats then
+ model.norm_stats = opts.norm_stats
+ end
+
+ if opts.roc_thresholds then
+ model.roc_thresholds = opts.roc_thresholds
+ end
+
+ return ucl.to_format(model, 'msgpack')
+end
+
+exports.MODEL_MAGIC = MODEL_MAGIC
+exports.MODEL_FORMAT_VERSION = MODEL_FORMAT_VERSION
+
+return exports
local rspamd_tensor = require "rspamd_tensor"
local rspamd_util = require "rspamd_util"
local ucl = require "ucl"
+local neural_external = require "lua_neural_external"
local N = 'neural'
per_provider_pca = false, -- if true, apply PCA per provider before fusion (not active yet)
},
disable_symbols_input = false, -- when true, do not use symbols provider unless explicitly listed
+ -- External pretrained model support
+ external_model = nil, -- external model configuration (see lua_neural_external)
}
-- Rule structure:
result_to_vector = result_to_vector,
settings = settings,
spawn_train = spawn_train,
+ -- External model support
+ neural_external = neural_external,
}
--- /dev/null
+--[[
+Copyright (c) 2026, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[
+Export neural network model from Redis to RNM1 format for distribution.
+
+This tool extracts a trained neural model from Redis and saves it as
+a portable msgpack file that can be loaded by other rspamd instances.
+
+Usage:
+ rspamadm neural_export -o model.rnm
+
+Options:
+ -c, --config Path to rspamd config (to get Redis connection)
+ -r, --rule Rule name to export (default: 'default')
+ -s, --settings Settings ID to export (default: 'default')
+ -o, --output Output file path (required)
+ -v, --version Specific model version to export (default: latest)
+ --list List available models instead of exporting
+ --digest Also print providers digest for URL construction
+]]--
+
+local rspamd_logger = require "rspamd_logger"
+local lua_redis = require "lua_redis"
+local argparse = require "argparse"
+local ucl = require "ucl"
+local rspamd_util = require "rspamd_util"
+local rspamd_kann = require "rspamd_kann"
+local rspamd_tensor = require "rspamd_tensor"
+
+-- Command line definition for `rspamadm neural_export`
+local parser = argparse()
+ :name "rspamadm neural_export"
+ :description "Export neural network model from Redis to RNM1 format"
+ :help_description_margin(32)
+
+-- Config file; defaults to rspamd.conf inside the configured CONFDIR
+parser:option "-c --config"
+ :description "Path to config file"
+ :argname("<cfg>")
+ :default(rspamd_paths["CONFDIR"] .. "/" .. "rspamd.conf")
+-- Neural rule to export; defaults to 'default'
+parser:option "-r --rule"
+ :description "Rule name to export"
+ :argname("<rule>")
+ :default("default")
+-- Settings ID to export; defaults to 'default'
+parser:option "-s --settings"
+ :description "Settings ID to export"
+ :argname("<id>")
+ :default("default")
+-- Output path; no default is declared here
+parser:option "-o --output"
+ :description "Output file path"
+ :argname("<file>")
+-- Model version; the value is converted with tonumber
+parser:option "-v --version"
+ :description "Specific model version to export"
+ :argname("<version>")
+ :convert(tonumber)
+-- Flags (no argument)
+parser:flag "--list"
+ :description "List available models instead of exporting"
+parser:flag "--digest"
+ :description "Print providers digest for URL construction"
+
+-- Model format constants (must match lua_neural_external)
+local MODEL_MAGIC = "RNM1"
+local MODEL_FORMAT_VERSION = 1
+
+--- Load the rspamd configuration and resolve the Redis settings of the
+--- neural module
+-- @param opts parsed command line options (`config` is the config path)
+-- @return redis params, or nil when any step fails (the reason is logged)
+local function init_redis(opts)
+  local loaded, load_err = rspamd_config:load_ucl(opts['config'])
+  if not loaded then
+    rspamd_logger.errx('cannot parse %s: %s', opts['config'], load_err)
+    return nil
+  end
+
+  local processed, rcl_err = rspamd_config:parse_rcl({ 'logging', 'worker' })
+  if not processed then
+    rspamd_logger.errx('cannot process %s: %s', opts['config'], rcl_err)
+    return nil
+  end
+
+  local redis_params = lua_redis.parse_redis_server('neural')
+  if redis_params then
+    return redis_params
+  end
+
+  rspamd_logger.errx('cannot get Redis configuration for neural module')
+  return nil
+end
+
+--- Look up the neural rule named by the `rule` command line option
+-- @param opts parsed command line options
+-- @return rule table, or nil when the neural config or the rule is missing
+local function get_neural_rule(opts)
+  local neural_opts = rspamd_config:get_all_opt('neural')
+  if not neural_opts then
+    rspamd_logger.errx('no neural configuration found')
+    return nil
+  end
+
+  local rule_name = opts['rule']
+  local rules = neural_opts['rules'] or {}
+
+  -- Legacy layout: the neural section itself acts as the 'default' rule
+  if neural_opts.train and not rules['default'] then
+    rules['default'] = neural_opts
+  end
+
+  local rule = rules[rule_name]
+  if rule then
+    return rule
+  end
+
+  rspamd_logger.errx('rule "%s" not found in neural config', rule_name)
+  return nil
+end
+
+--- Compute the Redis prefix under which the models of a rule/settings pair
+--- are stored
+-- Delegates to redis_ann_prefix of the neural library, passing only the
+-- rule prefix ('default' when the rule has none).
+local function build_redis_prefix(rule, settings_name)
+  local neural_common = require "plugins/neural"
+  local prefix_source = { prefix = rule.prefix or 'default' }
+
+  return neural_common.redis_ann_prefix(prefix_source, settings_name)
+end
+
+--- List available models from Redis
+-- Fetches every member of the sorted set at `prefix` together with its
+-- score, highest score first.
+-- @param redis_params redis connection params
+-- @param prefix key of the sorted set holding the model profiles
+-- @param callback function(err, data); data alternates member, score, ...
+local function list_models(redis_params, prefix, callback)
+  local function members_cb(err, data)
+    if err then
+      callback(err, nil)
+    else
+      callback(nil, data)
+    end
+  end
+
+  -- NOTE(review): rspamadm_ev_base is expected to be the event base global
+  -- provided by the rspamadm Lua environment; a nil event base leaves the
+  -- asynchronous request with nothing to run on - confirm
+  local ret = lua_redis.redis_make_request_taskless(rspamadm_ev_base,
+      rspamd_config,
+      redis_params,
+      nil,
+      false,
+      members_cb,
+      'ZREVRANGE',
+      { prefix, '0', '-1', 'WITHSCORES' }
+  )
+
+  -- Do not fail silently: when the request could not be issued the callback
+  -- would otherwise never run
+  if not ret then
+    callback('cannot make redis request', nil)
+  end
+end
+
+--- Load ANN data from Redis
+-- Reads the fields `ann`, `roc_thresholds`, `pca`, `providers_meta` and
+-- `norm_stats` of the hash at `ann_key`, restores the network and the PCA
+-- tensor from their zstd-compressed form and parses the two UCL fields.
+-- @param redis_params redis connection params
+-- @param ann_key Redis key of the model hash
+-- @param callback function(err, result); result.ann is only set when the
+-- network could be decompressed and loaded
+local function load_ann_from_redis(redis_params, ann_key, callback)
+  local function data_cb(err, data)
+    if err then
+      callback(err, nil)
+      return
+    end
+
+    if type(data) ~= 'table' then
+      callback("no data found at key: " .. ann_key, nil)
+      return
+    end
+
+    -- Reply order follows the HMGET field list below
+    local result = {
+      ann_data = data[1],
+      roc_thresholds = data[2],
+      pca_data = data[3],
+      providers_meta = data[4],
+      norm_stats = data[5],
+    }
+
+    -- Decompress ann_data
+    if result.ann_data then
+      local dec_err, dec_data = rspamd_util.zstd_decompress(result.ann_data)
+      if not dec_err and dec_data then
+        result.ann = rspamd_kann.load(dec_data)
+      else
+        -- Say why the network is missing instead of leaving it unexplained
+        rspamd_logger.errx('cannot decompress ANN data from %s: %s', ann_key, dec_err)
+      end
+    end
+
+    -- Decompress and load PCA
+    if result.pca_data then
+      local dec_err, dec_data = rspamd_util.zstd_decompress(result.pca_data)
+      if not dec_err and dec_data then
+        result.pca = rspamd_tensor.load(dec_data)
+      else
+        rspamd_logger.errx('cannot decompress PCA data from %s: %s', ann_key, dec_err)
+      end
+    end
+
+    -- Parse JSON fields; a field that fails to parse keeps its raw value
+    if result.roc_thresholds then
+      local roc_parser = ucl.parser()
+      local ok = roc_parser:parse_text(result.roc_thresholds)
+      if ok then
+        result.roc_thresholds = roc_parser:get_object()
+      end
+    end
+
+    if result.norm_stats then
+      local norm_parser = ucl.parser()
+      local ok = norm_parser:parse_text(result.norm_stats)
+      if ok then
+        result.norm_stats = norm_parser:get_object()
+      end
+    end
+
+    callback(nil, result)
+  end
+
+  -- NOTE(review): rspamadm_ev_base is expected to be the event base global
+  -- provided by the rspamadm Lua environment; a nil event base leaves the
+  -- asynchronous request with nothing to run on - confirm
+  local ret = lua_redis.redis_make_request_taskless(rspamadm_ev_base,
+      rspamd_config,
+      redis_params,
+      nil,
+      false,
+      data_cb,
+      'HMGET',
+      { ann_key, 'ann', 'roc_thresholds', 'pca', 'providers_meta', 'norm_stats' },
+      { opaque_data = true }
+  )
+
+  -- Do not fail silently: when the request could not be issued the callback
+  -- would otherwise never run
+  if not ret then
+    callback('cannot make redis request', nil)
+  end
+end
+
+--- Decode one profile entry of the models sorted set
+-- @param profile_str serialized profile as stored in Redis
+-- @return decoded profile object, or nil when it cannot be parsed
+local function load_profile(profile_str)
+  local decoder = ucl.parser()
+
+  if decoder:parse_string(profile_str) then
+    return decoder:get_object()
+  end
+
+  return nil
+end
+
+--- Pack a network (and optional PCA) into the RNM1 msgpack envelope
+-- The ANN and PCA payloads are stored zstd-compressed; `norm_stats` and
+-- `roc_thresholds` are copied from `opts` when present.
+-- @param ann kann_t object
+-- @param pca optional PCA tensor
+-- @param providers_digest providers config digest
+-- @param opts optional { model_version, norm_stats, roc_thresholds }
+-- @return msgpack-encoded model as produced by ucl.to_format
+local function serialize_model(ann, pca, providers_digest, opts)
+  local extra = opts or {}
+
+  local envelope = {
+    magic = MODEL_MAGIC,
+    version = MODEL_FORMAT_VERSION,
+  }
+
+  -- Network payload, compressed in memory
+  envelope.ann_data = rspamd_util.zstd_compress(ann:save())
+  envelope.model_version = extra.model_version or 1
+  envelope.providers_digest = providers_digest
+  envelope.created_at = os.time()
+
+  if pca then
+    envelope.pca_data = rspamd_util.zstd_compress(pca:save())
+  end
+
+  -- Optional metadata; falsy values are left out, as before
+  envelope.norm_stats = extra.norm_stats or nil
+  envelope.roc_thresholds = extra.roc_thresholds or nil
+
+  return ucl.to_format(envelope, 'msgpack')
+end
+
+--- Print the models found in a ZREVRANGE ... WITHSCORES reply
+-- @param data alternating list: profile, score, profile, score, ...
+local function print_models(data)
+  for i = 1, #data, 2 do
+    local score = tonumber(data[i + 1]) or 0
+    local profile = load_profile(data[i])
+
+    if profile then
+      print(string.format("  Version: %s, Key: %s, Updated: %s",
+          profile.version or 0,
+          profile.redis_key or "?",
+          os.date("%Y-%m-%d %H:%M:%S", score)))
+    end
+  end
+end
+
+--- Choose the profile to export from a ZREVRANGE ... WITHSCORES reply
+-- @param data alternating list: profile, score, profile, score, ...
+-- @param target_version requested version number, or nil for the entry with
+-- the highest score
+-- @return selected profile or nil
+local function select_profile(data, target_version)
+  local selected, selected_score
+
+  for i = 1, #data, 2 do
+    local score = tonumber(data[i + 1]) or 0
+    local profile = load_profile(data[i])
+
+    if profile then
+      if target_version then
+        -- Compare numerically so that a version stored as a string matches
+        if tonumber(profile.version) == target_version then
+          return profile
+        end
+      elseif not selected or score > selected_score then
+        selected, selected_score = profile, score
+      end
+    end
+  end
+
+  return selected
+end
+
+--- Write the serialized model to a file
+-- @return true on success, or nil plus an error message
+local function write_output(path, payload)
+  local out, open_err = io.open(path, "wb")
+  if not out then
+    return nil, open_err
+  end
+
+  local written, write_err = out:write(tostring(payload))
+  out:close()
+
+  if not written then
+    return nil, write_err
+  end
+
+  return true
+end
+
+--- Entry point of `rspamadm neural_export`
+-- Depending on the options this prints the providers digest, lists the
+-- models stored for a rule/settings pair, or exports one model to a file.
+-- The Redis helpers are callback based, so they are called directly: errors
+-- raised while issuing a request reach the caller instead of being dropped.
+-- @param args raw command line arguments
+local function handler(args)
+  local opts = parser:parse(args)
+
+  -- Initialize
+  local redis_params = init_redis(opts)
+  if not redis_params then
+    os.exit(1)
+  end
+
+  local rule = get_neural_rule(opts)
+  if not rule then
+    os.exit(1)
+  end
+
+  -- Get providers digest
+  local neural_common = require "plugins/neural"
+  local providers_digest = neural_common.providers_config_digest(rule.providers, rule)
+
+  -- Print digest if requested; with neither --list nor -o this is all there is to do
+  if opts['digest'] then
+    print(string.format("Providers digest: %s", providers_digest))
+    if not opts['list'] and not opts['output'] then
+      return
+    end
+  end
+
+  local settings_name = opts['settings']
+  local prefix = build_redis_prefix(rule, settings_name)
+
+  -- List mode: show available models
+  if opts['list'] then
+    print(string.format("\nAvailable models for rule '%s', settings '%s':", opts['rule'], settings_name))
+    print(string.format("Redis prefix: %s", prefix))
+    print("")
+
+    list_models(redis_params, prefix, function(err, data)
+      if err then
+        rspamd_logger.errx('failed to list models: %s', err)
+        return
+      end
+
+      if not data or #data == 0 then
+        print("No models found")
+        return
+      end
+
+      print_models(data)
+    end)
+    return
+  end
+
+  -- Export mode: need output file
+  if not opts['output'] then
+    rspamd_logger.errx('output file is required for export (use -o <file>)')
+    os.exit(1)
+  end
+
+  local output_file = opts['output']
+  local target_version = opts['version']
+
+  -- Find and load the model
+  list_models(redis_params, prefix, function(err, data)
+    if err then
+      rspamd_logger.errx('failed to list models: %s', err)
+      return
+    end
+
+    if not data or #data == 0 then
+      rspamd_logger.errx('no models found in Redis')
+      return
+    end
+
+    -- Find the target version (or use latest)
+    local selected_profile = select_profile(data, target_version)
+    if not selected_profile then
+      if target_version then
+        rspamd_logger.errx('model version %s not found', target_version)
+      else
+        rspamd_logger.errx('no suitable model found')
+      end
+      return
+    end
+
+    local ann_key = selected_profile.redis_key
+    rspamd_logger.messagex('Loading model from key: %s (version %s)',
+        ann_key, selected_profile.version or 0)
+
+    -- Load ANN data
+    load_ann_from_redis(redis_params, ann_key, function(load_err, model_data)
+      if load_err then
+        rspamd_logger.errx('failed to load model data: %s', load_err)
+        return
+      end
+
+      if not model_data.ann then
+        rspamd_logger.errx('failed to load ANN from Redis')
+        return
+      end
+
+      -- Serialize to RNM1 format
+      local rnm_data = serialize_model(model_data.ann, model_data.pca, providers_digest, {
+        model_version = selected_profile.version or 1,
+        norm_stats = model_data.norm_stats,
+        roc_thresholds = model_data.roc_thresholds,
+      })
+
+      -- Write to file; report a failure instead of raising from a callback
+      local written, write_err = write_output(output_file, rnm_data)
+      if not written then
+        rspamd_logger.errx('cannot write %s: %s', output_file, write_err)
+        return
+      end
+
+      rspamd_logger.messagex('Exported model to: %s', output_file)
+      rspamd_logger.messagex('  Model version: %s', selected_profile.version or 0)
+      rspamd_logger.messagex('  Providers digest: %s', providers_digest:sub(1, 16) .. '...')
+      rspamd_logger.messagex('  File size: %s bytes', #rnm_data)
+    end)
+  end)
+end
+
+return {
+ name = "neural_export",
+ aliases = { "neural_export" },
+ handler = handler,
+ description = parser._description
+}
LUA_FUNCTION_DEF(kann, save);
LUA_FUNCTION_DEF(kann, train1);
LUA_FUNCTION_DEF(kann, apply1);
+LUA_FUNCTION_DEF(kann, merge_weights);
+LUA_FUNCTION_DEF(kann, is_compatible);
/* Method table for kann objects: one entry per method declared with
 * LUA_FUNCTION_DEF(kann, ...), followed by the __gc handler. */
static luaL_reg rspamd_kann_m[] = {
LUA_INTERFACE_DEF(kann, save),
LUA_INTERFACE_DEF(kann, train1),
LUA_INTERFACE_DEF(kann, apply1),
+ LUA_INTERFACE_DEF(kann, merge_weights),
+ LUA_INTERFACE_DEF(kann, is_compatible),
{"__gc", lua_kann_destroy}, /* __gc metamethod entry, handled by lua_kann_destroy */
{NULL, NULL}, /* terminating entry: both name and function are NULL */
};
}
return 1;
+}
+
+/***
+ * @function kann:merge_weights(other_ann, alpha)
+ * Merge weights from another ANN into this one using linear interpolation.
+ * w_new = (1 - alpha) * w_self + alpha * w_other
+ * The receiver is modified in place; `other_ann` is only read.
+ * Raises a Lua error if either argument is not a kann object, or if alpha is
+ * NaN or lies outside [0.0, 1.0].
+ * @param {kann} other_ann source ANN to merge from
+ * @param {number} alpha weight for source ANN (0.0 - 1.0)
+ * @return {boolean[,string]} true on success; false plus an error message when
+ * the architectures are incompatible or the merge itself fails
+ */
+static int
+lua_kann_merge_weights(lua_State *L)
+{
+	kann_t *self = lua_check_kann(L, 1);
+	kann_t *other = lua_check_kann(L, 2);
+	double alpha = luaL_checknumber(L, 3);
+
+	if (self && other) {
+		/*
+		 * Negated range test on purpose: every ordered comparison with NaN
+		 * is false, so `alpha < 0.0 || alpha > 1.0` would let NaN through
+		 * and the interpolation would then turn every weight into NaN.
+		 */
+		if (!(alpha >= 0.0 && alpha <= 1.0)) {
+			return luaL_error(L, "alpha must be between 0.0 and 1.0, got %f", alpha);
+		}
+
+		/* Check compatibility first */
+		if (!kann_is_compatible(self, other)) {
+			lua_pushboolean(L, false);
+			lua_pushstring(L, "incompatible ANN architectures");
+			return 2;
+		}
+
+		/* Interpolates the weights of `self` in place */
+		int ret = kann_merge_weights(self, other, (float) alpha);
+
+		if (ret == 0) {
+			lua_pushboolean(L, true);
+			return 1;
+		}
+		else {
+			lua_pushboolean(L, false);
+			lua_pushstring(L, "merge failed");
+			return 2;
+		}
+	}
+	else {
+		return luaL_error(L, "invalid arguments: two kann objects required");
+	}
+}
+
+/***
+ * @function kann:is_compatible(other_ann)
+ * Tell whether this ANN and `other_ann` have compatible architecture for
+ * weight merging. Raises a Lua error unless both arguments are kann objects.
+ * @param {kann} other_ann ANN to check compatibility with
+ * @return {boolean} true if compatible, false otherwise
+ */
+static int
+lua_kann_is_compatible(lua_State *L)
+{
+	kann_t *self = lua_check_kann(L, 1);
+	kann_t *other = lua_check_kann(L, 2);
+
+	if (!self || !other) {
+		return luaL_error(L, "invalid arguments: two kann objects required");
+	}
+
+	/* Only a result of exactly 1 is reported as compatible */
+	lua_pushboolean(L, kann_is_compatible(self, other) == 1);
+
+	return 1;
}
\ No newline at end of file
local lua_verdict = require "lua_verdict"
local neural_common = require "plugins/neural"
local neural_learn = require "lua_neural_learn"
+local neural_external = require "lua_neural_external"
local rspamd_kann = require "rspamd_kann"
local rspamd_logger = require "rspamd_logger"
local rspamd_tensor = require "rspamd_tensor"
)
end
+--- External model support functions
+
+-- Apply a loaded external model to a settings element.
+--
+-- The external ANN is loaded from `model`. If the element already has a local
+-- ANN with a compatible architecture the two are blended via
+-- `kann:merge_weights`; if the architectures differ, or there is no local ANN,
+-- the external ANN is used on its own. Afterwards the model's PCA,
+-- normalization stats and ROC thresholds (when present) are copied onto
+-- `set.ann`, and `rule.external_model` is marked as loaded at
+-- `model.model_version`.
+--
+-- @param rule neural rule configuration
+-- @param set settings element
+-- @param model parsed external model data
+-- @param ev_base event base (optional, for storing base model)
+-- @return true when the model was applied, false when it could not be
+--         loaded or merged (or when rule/model are not usable)
+local function apply_external_model(rule, set, model, ev_base)
+  local ext_cfg = rule.external_model
+  if not ext_cfg or not model then
+    return false
+  end
+
+  -- Load external ANN
+  local ext_ann, ann_err = neural_external.load_ann(model)
+  if not ext_ann then
+    rspamd_logger.errx(rspamd_config, 'failed to load external ANN for %s:%s: %s',
+        rule.prefix, set.name, ann_err or "unknown")
+    return false
+  end
+
+  -- Store base model for future re-merge; a no-op without an event base
+  local function store_base_model()
+    if not ev_base then
+      return
+    end
+    neural_external.store_base_model(rule.redis, ev_base, set.ann.redis_key, model, function(store_err)
+      if store_err then
+        rspamd_logger.warnx(rspamd_config, 'failed to store base model: %s', store_err)
+      end
+    end)
+  end
+
+  -- Check if we have a local ANN to merge with
+  if set.ann and set.ann.ann then
+    if ext_ann:is_compatible(set.ann.ann) then
+      -- Merge weights. The receiver (ext_ann) is updated in place as
+      -- (1 - alpha) * external + alpha * local.
+      -- NOTE(review): with this call order merge_alpha is the weight of the
+      -- *local* ANN; confirm that this is the intended meaning of the option.
+      local alpha = ext_cfg.merge_alpha or 0.5
+      local merged, merge_err = ext_ann:merge_weights(set.ann.ann, alpha)
+      if not merged then
+        rspamd_logger.errx(rspamd_config, 'failed to merge ANNs for %s:%s: %s',
+            rule.prefix, set.name, merge_err or "unknown")
+        return false
+      end
+
+      rspamd_logger.infox(rspamd_config,
+          'merged external model (version=%s, alpha=%s) with local ANN for %s:%s',
+          model.model_version, alpha, rule.prefix, set.name)
+
+      -- merge_weights only returns a success flag, so the merged network is
+      -- ext_ann itself; storing the flag would replace the ANN with `true`.
+      set.ann.ann = ext_ann
+      set.ann.version = (set.ann.version or 0) + 1
+      set.ann.external_version = model.model_version
+      set.ann.external_source = ext_cfg.url
+
+      store_base_model()
+    else
+      rspamd_logger.warnx(rspamd_config,
+          'external ANN architecture incompatible with local ANN for %s:%s, using external only',
+          rule.prefix, set.name)
+      set.ann.ann = ext_ann
+      set.ann.version = model.model_version or 1
+      set.ann.external_version = model.model_version
+      set.ann.external_source = ext_cfg.url
+      -- Deliberately no early return here: the external ANN must still get
+      -- its own PCA/normalization data, and last_version must be recorded,
+      -- otherwise the periodic check would re-apply the same model forever.
+    end
+  else
+    -- No local ANN, just use external
+    rspamd_logger.infox(rspamd_config,
+        'loaded external model (version=%s) as initial ANN for %s:%s',
+        model.model_version, rule.prefix, set.name)
+
+    set.ann = {
+      version = model.model_version or 1,
+      redis_key = neural_common.new_ann_key(rule, set, model.model_version or 1),
+      external_version = model.model_version,
+      external_source = ext_cfg.url,
+      ann = ext_ann,
+      providers_digest = ext_cfg.providers_digest,
+    }
+
+    store_base_model()
+  end
+
+  -- Load PCA if present
+  local pca = neural_external.load_pca(model)
+  if pca then
+    set.ann.pca = pca
+  end
+
+  -- Copy normalization stats
+  if model.norm_stats then
+    set.ann.norm_stats = model.norm_stats
+  end
+
+  -- Copy ROC thresholds
+  if model.roc_thresholds then
+    set.ann.roc_thresholds = model.roc_thresholds
+  end
+
+  -- Update external model state
+  ext_cfg.last_version = model.model_version
+  ext_cfg.loaded = true
+
+  return true
+end
+
+-- Register the external model map for a rule; meant to be called at config
+-- time. Each time the map hands over a model, that model is applied to every
+-- table-valued entry of `rule.settings`; a load error is logged and the
+-- update is dropped.
+-- @param rule neural rule configuration
+-- @return false when the rule has no external model URL, otherwise whatever
+--         neural_external.register_model_map returns
+local function register_external_model_map(rule)
+  local ext_cfg = rule.external_model
+  if not (ext_cfg and ext_cfg.url) then
+    return false
+  end
+
+  -- The callback closes over `rule`, so no separate reference is needed
+  return neural_external.register_model_map(rspamd_config, rule, ext_cfg.providers_digest,
+      function(model, err)
+        if err then
+          rspamd_logger.errx(rspamd_config, 'external model load failed for %s: %s',
+              rule.prefix, err)
+          return
+        end
+
+        -- No event base is available here, hence the explicit nil
+        for _, elt in pairs(rule.settings) do
+          if type(elt) == 'table' then
+            apply_external_model(rule, elt, model, nil)
+          end
+        end
+      end)
+end
+
+-- Look for an updated external model and apply it to all settings of a rule.
+-- Compares the version of the model cached for the rule's URL with the
+-- version recorded in `rule.external_model.last_version`; does nothing when
+-- the rule has no external model, nothing is cached, or the versions match.
+-- @param worker unused here
+-- @param cfg config object used as the logging context
+-- @param ev_base event base passed through to apply_external_model
+-- @param rule neural rule configuration
+local function check_external_model(worker, cfg, ev_base, rule)
+  local ext_cfg = rule.external_model
+  if not ext_cfg then
+    return
+  end
+
+  local fresh = neural_external.get_cached_model(ext_cfg.url)
+  if not fresh or fresh.model_version == ext_cfg.last_version then
+    return
+  end
+
+  rspamd_logger.infox(cfg, 'external model updated for %s: version %s -> %s',
+      rule.prefix, ext_cfg.last_version or 0, fresh.model_version)
+
+  -- Non-table entries of rule.settings are skipped
+  for _, elt in pairs(rule.settings) do
+    if type(elt) == 'table' then
+      apply_external_model(rule, elt, fresh, ev_base)
+    end
+  end
+end
+
-- Used to check an element in Redis serialized as JSON
-- for some specific rule + some specific setting
-- This function tries to load more fresh or more specific ANNs in lieu of
end
end
+ -- External model configuration
+ if rule_elt.external_model then
+ local providers_digest = neural_common.providers_config_digest(rule_elt.providers, rule_elt)
+ rule_elt.external_model = neural_external.create_external_config(rule_elt, providers_digest)
+ if rule_elt.external_model then
+ rspamd_logger.infox(rspamd_config, "configured external model for rule %s: url=%s, merge_alpha=%s",
+ k, rule_elt.external_model.url, rule_elt.external_model.merge_alpha)
+ end
+ end
+
rspamd_logger.infox(rspamd_config, "register ann rule %s", k)
settings.rules[k] = rule_elt
+
+ -- Register external model map if configured
+ if rule_elt.external_model then
+ register_external_model_map(rule_elt)
+ end
+
rspamd_config:set_metric_symbol({
name = rule_elt.symbol_spam,
score = 0.0,
function(_, _)
-- Clean old ANNs
cleanup_anns(rule, cfg, ev_base)
+ -- Check for external model updates
+ check_external_model(worker, cfg, ev_base, rule)
return check_anns(worker, cfg, ev_base, rule, maybe_train_existing_ann,
'try_train_ann')
end)