]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Feature] Add external pretrained neural model support
authorVsevolod Stakhov <vsevolod@rspamd.com>
Fri, 6 Mar 2026 10:49:43 +0000 (10:49 +0000)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Fri, 6 Mar 2026 11:27:41 +0000 (11:27 +0000)
This commit adds the ability to load pretrained neural network models
from external sources (HTTP/HTTPS) and merge them with locally trained
weights. Users can receive a pretrained model and fine-tune it with
their own data.

Model format (msgpack with magic "RNM1"):
- magic: format identifier
- version: format version (currently 1)
- model_version: model training version
- providers_digest: must match local providers config
- ann_data: serialized KANN (zstd compressed)
- pca_data: optional PCA matrix
- norm_stats, roc_thresholds: optional metadata

Key changes:
- lualib/lua_neural_external.lua: new module for external model handling
  - Model parsing, KANN loading, weight merging via interpolation
  - Map-based loading with signature verification support
  - Base model storage in Redis for future re-merge

- contrib/kann/kann.c: add kann_merge_weights() for weight interpolation
  - w_new = alpha * w_external + (1-alpha) * w_local
  - kann_is_compatible() for architecture compatibility check

- src/lua/lua_kann.c: Lua bindings for merge_weights and is_compatible

- Neural plugin integration:
  - Register external model as callback map at config time
  - Apply loaded model to all settings elements
  - Automatic update checking via map infrastructure

Configuration example:
  neural {
    rules {
      spam_filter = {
        external_model = {
          url = "https://your-provider.com/models/<digest>";
          sign_key = "your_key";
          merge_alpha = 0.6;  # 60% external, 40% local
        };
      };
    };
  }

contrib/kann/kann.c
contrib/kann/kann.h
lualib/lua_neural_external.lua [new file with mode: 0644]
lualib/plugins/neural.lua
lualib/rspamadm/neural_export.lua [new file with mode: 0644]
src/lua/lua_kann.c
src/plugins/lua/neural.lua

index 658f98a441b265ebff076ef5a7f87dc6a14a4ae4..86723bd9d350f455721a70dbba4801ca39495ae2 100644 (file)
@@ -562,6 +562,87 @@ kann_t *kann_load(const char *fn)
        return ann;
 }
 
+/*******************************
+ *** @@MERGE: weight merging ***
+ ********************************/
+
+/**
+ * Merge weights from source ANN into destination ANN using linear interpolation.
+ * w_dst = (1 - alpha) * w_dst + alpha * w_src
+ *
+ * Both ANNs must have the same architecture (same number of variables and constants).
+ *
+ * @param dst   destination ANN (modified in place)
+ * @param src   source ANN (not modified)
+ * @param alpha weight for source ANN (0.0 - 1.0)
+ * @return 0 on success, -1 on error
+ */
+int kann_merge_weights(kann_t *dst, const kann_t *src, float alpha)
+{
+       int n_var_dst, n_var_src;
+       int n_const_dst, n_const_src;
+
+       if (!dst || !src) {
+               return -1;
+       }
+
+       n_var_dst = kad_size_var(dst->n, dst->v);
+       n_var_src = kad_size_var(src->n, src->v);
+       n_const_dst = kad_size_const(dst->n, dst->v);
+       n_const_src = kad_size_const(src->n, src->v);
+
+       /* Check architecture compatibility */
+       if (n_var_dst != n_var_src) {
+               return -1;
+       }
+
+       if (n_const_dst != n_const_src) {
+               /* Constants should match, but we can still merge weights */
+               (void) n_const_dst; /* suppress unused warning */
+       }
+
+       /* Merge variable weights (trainable parameters) */
+       for (int i = 0; i < n_var_dst; i++) {
+               dst->x[i] = (1.0f - alpha) * dst->x[i] + alpha * src->x[i];
+       }
+
+       return 0;
+}
+
+/**
+ * Check if two ANNs have compatible architecture for weight merging.
+ *
+ * @param a first ANN
+ * @param b second ANN
+ * @return 1 if compatible, 0 otherwise
+ */
+int kann_is_compatible(const kann_t *a, const kann_t *b)
+{
+       int n_var_a, n_var_b;
+
+       if (!a || !b) {
+               return 0;
+       }
+
+       n_var_a = kad_size_var(a->n, a->v);
+       n_var_b = kad_size_var(b->n, b->v);
+
+       if (n_var_a != n_var_b) {
+               return 0;
+       }
+
+       /* Check input/output dimensions */
+       if (kann_dim_in(a) != kann_dim_in(b)) {
+               return 0;
+       }
+
+       if (kann_dim_out(a) != kann_dim_out(b)) {
+               return 0;
+       }
+
+       return 1;
+}
+
 /**********************************************
  *** @@LAYER: layers and model generation ***
  **********************************************/
index af0de5fbac101a4c37bf681819f89307c6355654..313c9e9df10aa48702bea736abe8029a4a1ee320 100644 (file)
@@ -233,6 +233,10 @@ void kann_save(const char *fn, kann_t *ann);
 kann_t *kann_load_fp(FILE *fp);
 kann_t *kann_load(const char *fn);
 
+/* weight merging */
+int kann_merge_weights(kann_t *dst, const kann_t *src, float alpha);
+int kann_is_compatible(const kann_t *a, const kann_t *b);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/lualib/lua_neural_external.lua b/lualib/lua_neural_external.lua
new file mode 100644 (file)
index 0000000..0330c9a
--- /dev/null
@@ -0,0 +1,484 @@
+--[[
+Copyright (c) 2026, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[
+External neural model loading and merging.
+
+This module provides functionality to load pretrained neural models
+from external sources (HTTP/HTTPS) via the Maps infrastructure
+and merge them with locally trained weights.
+
+Model format (msgpack):
+{
+  magic = "RNM1",           -- Rspamd Neural Model v1
+  version = 1,              -- format version
+  model_version = 123,      -- model training version (incremented on retrain)
+  providers_digest = "...", -- digest of providers config (must match local)
+  ann_data = "...",         -- serialized KANN (zstd compressed)
+  pca_data = "...",         -- optional PCA (zstd compressed)
+  norm_stats = {...},       -- normalization stats
+  roc_thresholds = {...},   -- ROC thresholds
+  created_at = timestamp,
+}
+
+Usage in neural config:
+  external_model = {
+    url = "https://your-provider.com/models/<digest>";
+    sign_key = "your_key";  -- optional signature verification
+    merge_alpha = 0.6;      -- 60% external, 40% local
+  };
+]]--
+
+local lua_redis = require "lua_redis"
+local rspamd_kann = require "rspamd_kann"
+local rspamd_logger = require "rspamd_logger"
+local rspamd_util = require "rspamd_util"
+local rspamd_text = require "rspamd_text"
+local ucl = require "ucl"
+
+-- Model format constants
+local MODEL_MAGIC = "RNM1"
+local MODEL_FORMAT_VERSION = 1
+
+local exports = {}
+
+-- Cache of loaded external models: url -> { model, map, callbacks }
+local external_model_cache = {}
+
+--- Parse external model from msgpack data
+-- @param data raw msgpack data (rspamd_text, possibly zstd compressed)
+-- @return table with model data or nil, error message
+function exports.parse_model(data)
+  if not data then
+    return nil, "no data"
+  end
+
+  -- Convert rspamd_text to string if needed
+  local data_str
+  if type(data) == 'userdata' or (type(data) == 'table' and data.cookie) then
+    data_str = tostring(data)
+  else
+    data_str = data
+  end
+
+  -- Try zstd decompression first
+  local decompressed
+  local err, decompressed_data = rspamd_util.zstd_decompress(data_str)
+  if not err and decompressed_data then
+    decompressed = tostring(decompressed_data)
+  else
+    -- Assume uncompressed
+    decompressed = data_str
+  end
+
+  -- Parse msgpack
+  local parser = ucl.parser()
+  local ok, parse_err = parser:parse_text(decompressed, 'msgpack')
+  if not ok then
+    return nil, "failed to parse msgpack: " .. (parse_err or "unknown error")
+  end
+
+  local model = parser:get_object()
+
+  -- Validate model format
+  if model.magic ~= MODEL_MAGIC then
+    return nil, string.format("invalid magic: expected %s, got %s",
+      MODEL_MAGIC, model.magic or "nil")
+  end
+
+  if model.version ~= MODEL_FORMAT_VERSION then
+    return nil, string.format("unsupported model version: %s (expected %s)",
+      model.version or "nil", MODEL_FORMAT_VERSION)
+  end
+
+  return model
+end
+
+--- Load KANN from model data
+-- @param model parsed model table
+-- @return kann_t object or nil, error
+function exports.load_ann(model)
+  if not model.ann_data then
+    return nil, "no ann_data in model"
+  end
+
+  -- Decompress ann_data
+  local ann_data_str = model.ann_data
+  if type(ann_data_str) == 'userdata' or (type(ann_data_str) == 'table' and ann_data_str.cookie) then
+    ann_data_str = tostring(ann_data_str)
+  end
+
+  local err, ann_data = rspamd_util.zstd_decompress(ann_data_str)
+  if err then
+    return nil, "failed to decompress ann_data: " .. err
+  end
+
+  local ann = rspamd_kann.load(ann_data)
+  if not ann then
+    return nil, "failed to load KANN from model"
+  end
+
+  return ann
+end
+
+--- Load PCA from model data
+-- @param model parsed model table
+-- @return tensor or nil
+function exports.load_pca(model)
+  if not model.pca_data then
+    return nil
+  end
+
+  local pca_data_str = model.pca_data
+  if type(pca_data_str) == 'userdata' or (type(pca_data_str) == 'table' and pca_data_str.cookie) then
+    pca_data_str = tostring(pca_data_str)
+  end
+
+  local err, pca_data = rspamd_util.zstd_decompress(pca_data_str)
+  if err then
+    rspamd_logger.warnx(rspamd_config, "failed to decompress pca_data: %s", err)
+    return nil
+  end
+
+  local rspamd_tensor = require "rspamd_tensor"
+  return rspamd_tensor.load(pca_data)
+end
+
+--- Check if model is compatible with local config
+-- @param model parsed model table
+-- @param providers_digest local providers digest
+-- @return boolean, reason
+function exports.is_compatible(model, providers_digest)
+  if not model.providers_digest then
+    return false, "model has no providers_digest"
+  end
+
+  if model.providers_digest ~= providers_digest then
+    return false, string.format("providers digest mismatch: model=%s, local=%s",
+      model.providers_digest:sub(1, 8), providers_digest:sub(1, 8))
+  end
+
+  return true, "compatible"
+end
+
+--- Merge weights from external ANN into local ANN using interpolation
+-- w_new = alpha * w_external + (1-alpha) * w_local
+-- @param external_ann kann_t from external model
+-- @param local_ann kann_t from local training
+-- @param alpha weight for external (0.0 - 1.0)
+-- @return merged kann_t or nil, error
+function exports.merge_weights(external_ann, local_ann, alpha)
+  if not external_ann or not local_ann then
+    return nil, "missing ann"
+  end
+
+  alpha = alpha or 0.5
+
+  -- Check compatibility first
+  local ok = external_ann:is_compatible(local_ann)
+  if not ok then
+    return nil, "incompatible ANN architectures"
+  end
+
+  -- Use the external ANN as base and merge local weights into it
+  local merged, err = external_ann:merge_weights(local_ann, 1.0 - alpha)
+
+  if not merged then
+    return nil, "merge failed: " .. (err or "unknown")
+  end
+
+  return external_ann
+end
+
+--- Build URL for external model based on providers digest
+-- @param base_url base URL
+-- @param providers_digest digest of providers config
+-- @return full URL
+function exports.build_model_url(base_url, providers_digest)
+  -- Remove trailing slash from base_url
+  base_url = base_url:gsub('/+$', '')
+  return string.format("%s/%s", base_url, providers_digest)
+end
+
+--- Register external model as a map
+-- This uses the Maps infrastructure for HTTP loading with signature verification
+-- @param cfg rspamd_config
+-- @param rule neural rule configuration
+-- @param providers_digest digest of providers config
+-- @param on_load_callback function(model_data, err) called when model is loaded/reloaded
+-- @return boolean success
+function exports.register_model_map(cfg, rule, providers_digest, on_load_callback)
+  local ext_cfg = rule.external_model
+  if not ext_cfg or not ext_cfg.url then
+    return false
+  end
+
+  local url = ext_cfg.url
+  local cache_key = url
+
+  -- Check if already registered
+  if external_model_cache[cache_key] then
+    return true
+  end
+
+  -- Map callback: called when map data is loaded
+  local function map_callback(data, map)
+    if not data then
+      rspamd_logger.errx(cfg, 'external neural model map returned no data for %s', url)
+      if on_load_callback then
+        on_load_callback(nil, "no data from map")
+      end
+      return
+    end
+
+    -- Parse model
+    local model, parse_err = exports.parse_model(data)
+    if not model then
+      rspamd_logger.errx(cfg, 'failed to parse external neural model from %s: %s', url, parse_err)
+      if on_load_callback then
+        on_load_callback(nil, parse_err)
+      end
+      return
+    end
+
+    -- Check compatibility
+    local compatible, reason = exports.is_compatible(model, providers_digest)
+    if not compatible then
+      rspamd_logger.errx(cfg, 'external neural model incompatible: %s', reason)
+      if on_load_callback then
+        on_load_callback(nil, reason)
+      end
+      return
+    end
+
+    rspamd_logger.infox(cfg, 'loaded external neural model from %s (version=%s)',
+      url, model.model_version or 0)
+
+    -- Update cache
+    external_model_cache[cache_key] = {
+      model = model,
+      last_version = model.model_version,
+      last_load = os.time(),
+    }
+
+    -- Call user callback
+    if on_load_callback then
+      on_load_callback(model, nil)
+    end
+  end
+
+  -- Create callback map
+  local map = cfg:add_map({
+    url = url,
+    type = 'callback',
+    description = string.format('External neural model for rule %s', rule.prefix or 'default'),
+    callback = map_callback,
+    opaque_data = true,  -- Get data as rspamd_text
+  })
+
+  if not map then
+    rspamd_logger.errx(cfg, 'failed to register external neural model map for %s', url)
+    return false
+  end
+
+  -- Set sign key if configured
+  if ext_cfg.sign_key then
+    map:set_sign_key(ext_cfg.sign_key)
+  end
+
+  external_model_cache[cache_key] = {
+    map = map,
+    callbacks = { on_load_callback },
+  }
+
+  return true
+end
+
+--- Get cached model data for URL
+-- @param url model URL
+-- @return cached model data or nil
+function exports.get_cached_model(url)
+  local cached = external_model_cache[url]
+  if cached and cached.model then
+    return cached.model
+  end
+  return nil
+end
+
+--- Create external model configuration for a neural rule
+-- @param rule neural rule configuration
+-- @param providers_digest digest of providers config
+-- @return table with external model config or nil
+function exports.create_external_config(rule, providers_digest)
+  local ext = rule.external_model
+  if not ext then
+    return nil
+  end
+
+  local url = ext.url
+  if not url then
+    -- Build URL from digest if base_url is provided
+    if ext.base_url then
+      url = exports.build_model_url(ext.base_url, providers_digest)
+    else
+      rspamd_logger.errx(rspamd_config, 'external_model requires url or base_url')
+      return nil
+    end
+  end
+
+  return {
+    url = url,
+    sign_key = ext.sign_key,
+    merge_strategy = ext.merge_strategy or "interpolate",
+    merge_alpha = ext.merge_alpha or 0.5,
+    check_interval = ext.check_interval or 86400, -- 24h
+    local_fine_tune = ext.local_fine_tune ~= false,
+    min_local_samples = ext.min_local_samples or 50,
+    providers_digest = providers_digest,
+    loaded = false,
+    last_version = nil,
+    last_check = nil,
+  }
+end
+
+--- Store external model metadata in Redis for later merge
+-- @param redis redis params
+-- @param ev_base event base
+-- @param ann_key Redis key for the ANN
+-- @param model_data parsed model data
+-- @param callback function(err)
+function exports.store_base_model(redis, ev_base, ann_key, model_data, callback)
+  -- Store base model version and compressed ann_data for re-merge
+  local base_key = ann_key .. "_base"
+
+  local function store_cb(err)
+    if err then
+      rspamd_logger.errx(rspamd_config, "failed to store base model: %s", err)
+    end
+    if callback then
+      callback(err)
+    end
+  end
+
+  -- Ensure ann_data is rspamd_text for opaque storage
+  local ann_data = model_data.ann_data
+  if type(ann_data) == 'string' then
+    -- Already compressed, convert to text
+    ann_data = rspamd_text.fromstring(ann_data)
+  end
+
+  -- Store base version and ann_data
+  lua_redis.redis_make_request_taskless(ev_base,
+    rspamd_config,
+    redis,
+    nil,
+    true, -- is write
+    store_cb,
+    'HMSET',
+    {
+      base_key,
+      'version', tostring(model_data.model_version or 0),
+      'ann_data', ann_data,
+      'providers_digest', model_data.providers_digest or '',
+      'created_at', tostring(model_data.created_at or os.time()),
+    },
+    { opaque_data = true }
+  )
+end
+
+--- Load base model from Redis for re-merge
+-- @param redis redis params
+-- @param ev_base event base
+-- @param ann_key Redis key for the ANN
+-- @param callback function(err, model_data)
+function exports.load_base_model(redis, ev_base, ann_key, callback)
+  local base_key = ann_key .. "_base"
+
+  local function load_cb(err, data)
+    if err then
+      callback(err, nil)
+      return
+    end
+
+    if type(data) ~= 'table' then
+      callback("no base model found", nil)
+      return
+    end
+
+    local model_data = {
+      model_version = tonumber(data[1]) or 0,
+      ann_data = data[2], -- rspamd_text
+      providers_digest = data[3],
+      created_at = tonumber(data[4]) or 0,
+    }
+
+    callback(nil, model_data)
+  end
+
+  lua_redis.redis_make_request_taskless(ev_base,
+    rspamd_config,
+    redis,
+    nil,
+    false, -- is write
+    load_cb,
+    'HMGET',
+    { base_key, 'version', 'ann_data', 'providers_digest', 'created_at' },
+    { opaque_data = true }
+  )
+end
+
+--- Serialize model to msgpack format
+-- @param ann kann_t object
+-- @param pca optional PCA tensor
+-- @param providers_digest providers config digest
+-- @param opts optional { norm_stats, roc_thresholds }
+-- @return rspamd_text with compressed msgpack data
+function exports.serialize_model(ann, pca, providers_digest, opts)
+  opts = opts or {}
+
+  -- Save ANN to memory
+  local ann_text = ann:save()
+  local ann_compressed = rspamd_util.zstd_compress(ann_text)
+
+  local model = {
+    magic = MODEL_MAGIC,
+    version = MODEL_FORMAT_VERSION,
+    model_version = opts.model_version or 1,
+    providers_digest = providers_digest,
+    ann_data = ann_compressed,
+    created_at = os.time(),
+  }
+
+  if pca then
+    local pca_text = pca:save()
+    model.pca_data = rspamd_util.zstd_compress(pca_text)
+  end
+
+  if opts.norm_stats then
+    model.norm_stats = opts.norm_stats
+  end
+
+  if opts.roc_thresholds then
+    model.roc_thresholds = opts.roc_thresholds
+  end
+
+  return ucl.to_format(model, 'msgpack')
+end
+
+exports.MODEL_MAGIC = MODEL_MAGIC
+exports.MODEL_FORMAT_VERSION = MODEL_FORMAT_VERSION
+
+return exports
index 5d4ee4e2de9bca882b5497b9febba34425253e2c..3ba3799da3a0d6d13b3ac2ee5a336aa568ea0b79 100644 (file)
@@ -24,6 +24,7 @@ local rspamd_logger = require "rspamd_logger"
 local rspamd_tensor = require "rspamd_tensor"
 local rspamd_util = require "rspamd_util"
 local ucl = require "ucl"
+local neural_external = require "lua_neural_external"
 
 local N = 'neural'
 
@@ -76,6 +77,8 @@ local default_options = {
     per_provider_pca = false,    -- if true, apply PCA per provider before fusion (not active yet)
   },
   disable_symbols_input = false, -- when true, do not use symbols provider unless explicitly listed
+  -- External pretrained model support
+  external_model = nil,          -- external model configuration (see lua_neural_external)
 }
 
 -- Rule structure:
@@ -1506,4 +1509,6 @@ return {
   result_to_vector = result_to_vector,
   settings = settings,
   spawn_train = spawn_train,
+  -- External model support
+  neural_external = neural_external,
 }
diff --git a/lualib/rspamadm/neural_export.lua b/lualib/rspamadm/neural_export.lua
new file mode 100644 (file)
index 0000000..3eb94db
--- /dev/null
@@ -0,0 +1,438 @@
+--[[
+Copyright (c) 2026, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[
+Export neural network model from Redis to RNM1 format for distribution.
+
+This tool extracts a trained neural model from Redis and saves it as
+a portable msgpack file that can be loaded by other rspamd instances.
+
+Usage:
+  rspamadm neural_export -o model.rnm
+
+Options:
+  -c, --config     Path to rspamd config (to get Redis connection)
+  -r, --rule       Rule name to export (default: 'default')
+  -s, --settings   Settings ID to export (default: 'default')
+  -o, --output     Output file path (required)
+  -v, --version    Specific model version to export (default: latest)
+  --list           List available models instead of exporting
+  --digest         Also print providers digest for URL construction
+]]--
+
+local rspamd_logger = require "rspamd_logger"
+local lua_redis = require "lua_redis"
+local argparse = require "argparse"
+local ucl = require "ucl"
+local rspamd_util = require "rspamd_util"
+local rspamd_kann = require "rspamd_kann"
+local rspamd_tensor = require "rspamd_tensor"
+
+local parser = argparse()
+    :name "rspamadm neural_export"
+    :description "Export neural network model from Redis to RNM1 format"
+    :help_description_margin(32)
+
+parser:option "-c --config"
+      :description "Path to config file"
+      :argname("<cfg>")
+      :default(rspamd_paths["CONFDIR"] .. "/" .. "rspamd.conf")
+parser:option "-r --rule"
+      :description "Rule name to export"
+      :argname("<rule>")
+      :default("default")
+parser:option "-s --settings"
+      :description "Settings ID to export"
+      :argname("<id>")
+      :default("default")
+parser:option "-o --output"
+      :description "Output file path"
+      :argname("<file>")
+parser:option "-v --version"
+      :description "Specific model version to export"
+      :argname("<version>")
+      :convert(tonumber)
+parser:flag "--list"
+      :description "List available models instead of exporting"
+parser:flag "--digest"
+      :description "Print providers digest for URL construction"
+
+-- Model format constants (must match lua_neural_external)
+local MODEL_MAGIC = "RNM1"
+local MODEL_FORMAT_VERSION = 1
+
+--- Load config and initialize Redis connection
+local function init_redis(opts)
+  local _r, err = rspamd_config:load_ucl(opts['config'])
+
+  if not _r then
+    rspamd_logger.errx('cannot parse %s: %s', opts['config'], err)
+    return nil
+  end
+
+  _r, err = rspamd_config:parse_rcl({ 'logging', 'worker' })
+  if not _r then
+    rspamd_logger.errx('cannot process %s: %s', opts['config'], err)
+    return nil
+  end
+
+  local redis_params = lua_redis.parse_redis_server('neural')
+  if not redis_params then
+    rspamd_logger.errx('cannot get Redis configuration for neural module')
+    return nil
+  end
+
+  return redis_params
+end
+
+--- Get neural rule configuration
+local function get_neural_rule(opts)
+  local neural_opts = rspamd_config:get_all_opt('neural')
+  if not neural_opts then
+    rspamd_logger.errx('no neural configuration found')
+    return nil
+  end
+
+  local rules = neural_opts['rules'] or {}
+  local rule_name = opts['rule']
+
+  -- Handle legacy config (neural_opts itself is the rule)
+  if not rules['default'] and neural_opts.train then
+    rules['default'] = neural_opts
+  end
+
+  local rule = rules[rule_name]
+  if not rule then
+    rspamd_logger.errx('rule "%s" not found in neural config', rule_name)
+    return nil
+  end
+
+  return rule
+end
+
+--- Build Redis prefix for a rule/settings combination
+local function build_redis_prefix(rule, settings_name)
+  local neural_common = require "plugins/neural"
+  return neural_common.redis_ann_prefix({
+    prefix = rule.prefix or 'default'
+  }, settings_name)
+end
+
+--- List available models from Redis
+local function list_models(redis_params, prefix, callback)
+  local function members_cb(err, data)
+    if err then
+      callback(err, nil)
+    else
+      callback(nil, data)
+    end
+  end
+
+  -- Get all profiles from sorted set (most recent first)
+  lua_redis.redis_make_request_taskless(nil,
+    rspamd_config,
+    redis_params,
+    nil,
+    false,
+    members_cb,
+    'ZREVRANGE',
+    { prefix, '0', '-1', 'WITHSCORES' }
+  )
+end
+
+--- Load ANN data from Redis
+local function load_ann_from_redis(redis_params, ann_key, callback)
+  local function data_cb(err, data)
+    if err then
+      callback(err, nil)
+      return
+    end
+
+    if type(data) ~= 'table' then
+      callback("no data found at key: " .. ann_key, nil)
+      return
+    end
+
+    -- data[1] = ann, data[2] = roc_thresholds, data[3] = pca,
+    -- data[4] = providers_meta, data[5] = norm_stats
+    local result = {
+      ann_data = data[1],
+      roc_thresholds = data[2],
+      pca_data = data[3],
+      providers_meta = data[4],
+      norm_stats = data[5],
+    }
+
+    -- Decompress ann_data
+    if result.ann_data then
+      local dec_err, dec_data = rspamd_util.zstd_decompress(result.ann_data)
+      if not dec_err and dec_data then
+        result.ann = rspamd_kann.load(dec_data)
+      end
+    end
+
+    -- Decompress and load PCA
+    if result.pca_data then
+      local dec_err, dec_data = rspamd_util.zstd_decompress(result.pca_data)
+      if not dec_err and dec_data then
+        result.pca = rspamd_tensor.load(dec_data)
+      end
+    end
+
+    -- Parse JSON fields
+    if result.roc_thresholds then
+      local roc_parser = ucl.parser()
+      local ok = roc_parser:parse_text(result.roc_thresholds)
+      if ok then
+        result.roc_thresholds = roc_parser:get_object()
+      end
+    end
+
+    if result.norm_stats then
+      local norm_parser = ucl.parser()
+      local ok = norm_parser:parse_text(result.norm_stats)
+      if ok then
+        result.norm_stats = norm_parser:get_object()
+      end
+    end
+
+    callback(nil, result)
+  end
+
+  lua_redis.redis_make_request_taskless(nil,
+    rspamd_config,
+    redis_params,
+    nil,
+    false,
+    data_cb,
+    'HMGET',
+    { ann_key, 'ann', 'roc_thresholds', 'pca', 'providers_meta', 'norm_stats' },
+    { opaque_data = true }
+  )
+end
+
+--- Load profile from Redis sorted set
+local function load_profile(profile_str)
+  local profile_parser = ucl.parser()
+  local ok = profile_parser:parse_string(profile_str)
+  if not ok then
+    return nil
+  end
+  return profile_parser:get_object()
+end
+
+--- Serialize model to RNM1 format
+local function serialize_model(ann, pca, providers_digest, opts)
+  opts = opts or {}
+
+  -- Save ANN to memory
+  local ann_text = ann:save()
+  local ann_compressed = rspamd_util.zstd_compress(ann_text)
+
+  local model = {
+    magic = MODEL_MAGIC,
+    version = MODEL_FORMAT_VERSION,
+    model_version = opts.model_version or 1,
+    providers_digest = providers_digest,
+    ann_data = ann_compressed,
+    created_at = os.time(),
+  }
+
+  if pca then
+    local pca_text = pca:save()
+    model.pca_data = rspamd_util.zstd_compress(pca_text)
+  end
+
+  if opts.norm_stats then
+    model.norm_stats = opts.norm_stats
+  end
+
+  if opts.roc_thresholds then
+    model.roc_thresholds = opts.roc_thresholds
+  end
+
+  return ucl.to_format(model, 'msgpack')
+end
+
+local function handler(args)
+  local opts = parser:parse(args)
+
+  -- Initialize
+  local redis_params = init_redis(opts)
+  if not redis_params then
+    os.exit(1)
+  end
+
+  local rule = get_neural_rule(opts)
+  if not rule then
+    os.exit(1)
+  end
+
+  -- Get providers digest
+  local neural_common = require "plugins/neural"
+  local providers_digest = neural_common.providers_config_digest(rule.providers, rule)
+
+  -- Print digest if requested
+  if opts['digest'] then
+    print(string.format("Providers digest: %s", providers_digest))
+    if not opts['list'] and not opts['output'] then
+      return
+    end
+  end
+
+  local settings_name = opts['settings']
+  local prefix = build_redis_prefix(rule, settings_name)
+
+  -- List mode: show available models
+  if opts['list'] then
+    print(string.format("\nAvailable models for rule '%s', settings '%s':", opts['rule'], settings_name))
+    print(string.format("Redis prefix: %s", prefix))
+    print("")
+
+    local co = coroutine.create(function()
+      list_models(redis_params, prefix, function(err, data)
+        if err then
+          rspamd_logger.errx('failed to list models: %s', err)
+          return
+        end
+
+        if not data or #data == 0 then
+          print("No models found")
+          return
+        end
+
+        -- Parse profiles (data is alternating: profile, score, profile, score, ...)
+        for i = 1, #data, 2 do
+          local profile_str = data[i]
+          local score = tonumber(data[i + 1]) or 0
+          local profile = load_profile(profile_str)
+
+          if profile then
+            local ts = os.date("%Y-%m-%d %H:%M:%S", score)
+            print(string.format("  Version: %s, Key: %s, Updated: %s",
+              profile.version or 0,
+            profile.redis_key or "?",
+            ts))
+          end
+        end
+      end)
+    end)
+    coroutine.resume(co)
+    return
+  end
+
+  -- Export mode: need output file
+  if not opts['output'] then
+    rspamd_logger.errx('output file is required for export (use -o <file>)')
+    os.exit(1)
+  end
+
+  local output_file = opts['output']
+  local target_version = opts['version']
+
+  -- Find and load the model
+  local co = coroutine.create(function()
+    list_models(redis_params, prefix, function(err, data)
+      if err then
+        rspamd_logger.errx('failed to list models: %s', err)
+        return
+      end
+
+      if not data or #data == 0 then
+        rspamd_logger.errx('no models found in Redis')
+        return
+      end
+
+      -- Find the target version (or use latest)
+      local selected_profile
+      local selected_score = 0
+
+      for i = 1, #data, 2 do
+        local profile_str = data[i]
+        local score = tonumber(data[i + 1]) or 0
+        local profile = load_profile(profile_str)
+
+        if profile then
+          if target_version then
+            -- Looking for specific version
+            if profile.version == target_version then
+              selected_profile = profile
+              selected_score = score
+              break
+            end
+          else
+            -- Use latest (highest score in sorted set = most recent)
+            if not selected_profile or score > selected_score then
+              selected_profile = profile
+              selected_score = score
+            end
+          end
+        end
+      end
+
+      if not selected_profile then
+        if target_version then
+          rspamd_logger.errx('model version %s not found', target_version)
+        else
+          rspamd_logger.errx('no suitable model found')
+        end
+        return
+      end
+
+      local ann_key = selected_profile.redis_key
+      rspamd_logger.messagex('Loading model from key: %s (version %s)',
+        ann_key, selected_profile.version or 0)
+
+      -- Load ANN data
+      load_ann_from_redis(redis_params, ann_key, function(load_err, model_data)
+        if load_err then
+          rspamd_logger.errx('failed to load model data: %s', load_err)
+          return
+        end
+
+        if not model_data.ann then
+          rspamd_logger.errx('failed to load ANN from Redis')
+          return
+        end
+
+        -- Serialize to RNM1 format
+        local rnm_data = serialize_model(model_data.ann, model_data.pca, providers_digest, {
+          model_version = selected_profile.version or 1,
+          norm_stats = model_data.norm_stats,
+          roc_thresholds = model_data.roc_thresholds,
+        })
+
+        -- Write to file
+        local out = assert(io.open(output_file, "wb"))
+        out:write(rnm_data)
+        out:close()
+
+        rspamd_logger.messagex('Exported model to: %s', output_file)
+        rspamd_logger.messagex('  Model version: %s', selected_profile.version or 0)
+        rspamd_logger.messagex('  Providers digest: %s', providers_digest:sub(1, 16) .. '...')
+        rspamd_logger.messagex('  File size: %s bytes', #rnm_data)
+      end)
+    end)
+  end)
+  coroutine.resume(co)
+end
+
+return {
+  name = "neural_export",
+  aliases = { "neural_export" },
+  handler = handler,
+  description = parser._description
+}
index e9b91e7ee62168848ae4a33cf0f748846ae1ec21..51742b22fba0a761c89ed5966f042ea9e001cc9e 100644 (file)
@@ -167,11 +167,15 @@ LUA_FUNCTION_DEF(kann, destroy);
 LUA_FUNCTION_DEF(kann, save);
 LUA_FUNCTION_DEF(kann, train1);
 LUA_FUNCTION_DEF(kann, apply1);
+LUA_FUNCTION_DEF(kann, merge_weights);
+LUA_FUNCTION_DEF(kann, is_compatible);
 
 static luaL_reg rspamd_kann_m[] = {
        LUA_INTERFACE_DEF(kann, save),
        LUA_INTERFACE_DEF(kann, train1),
        LUA_INTERFACE_DEF(kann, apply1),
+       LUA_INTERFACE_DEF(kann, merge_weights),
+       LUA_INTERFACE_DEF(kann, is_compatible),
        {"__gc", lua_kann_destroy},
        {NULL, NULL},
 };
@@ -1434,4 +1438,70 @@ lua_kann_apply1(lua_State *L)
        }
 
        return 1;
+}
+
+/***
+ * @function kann:merge_weights(other_ann, alpha)
+ * Merge weights from another ANN into this one using linear interpolation.
+ * w_new = (1 - alpha) * w_self + alpha * w_other
+ * @param {kann} other_ann source ANN to merge from
+ * @param {number} alpha weight for source ANN (0.0 - 1.0)
+ * @return {boolean} true on success, false on error
+ */
+static int
+lua_kann_merge_weights(lua_State *L)
+{
+       kann_t *self = lua_check_kann(L, 1);
+       kann_t *other = lua_check_kann(L, 2);
+       double alpha = luaL_checknumber(L, 3);
+
+       if (self && other) {
+               if (alpha < 0.0 || alpha > 1.0) {
+                       return luaL_error(L, "alpha must be between 0.0 and 1.0, got %f", alpha);
+               }
+
+               /* Check compatibility first */
+               if (!kann_is_compatible(self, other)) {
+                       lua_pushboolean(L, false);
+                       lua_pushstring(L, "incompatible ANN architectures");
+                       return 2;
+               }
+
+               int ret = kann_merge_weights(self, other, (float) alpha);
+
+               if (ret == 0) {
+                       lua_pushboolean(L, true);
+                       return 1;
+               }
+               else {
+                       lua_pushboolean(L, false);
+                       lua_pushstring(L, "merge failed");
+                       return 2;
+               }
+       }
+       else {
+               return luaL_error(L, "invalid arguments: two kann objects required");
+       }
+}
+
+/***
+ * @function kann:is_compatible(other_ann)
+ * Check if two ANNs have compatible architecture for weight merging.
+ * @param {kann} other_ann ANN to check compatibility with
+ * @return {boolean} true if compatible, false otherwise
+ */
+static int
+lua_kann_is_compatible(lua_State *L)
+{
+       kann_t *self = lua_check_kann(L, 1);
+       kann_t *other = lua_check_kann(L, 2);
+
+       if (self && other) {
+               int ret = kann_is_compatible(self, other);
+               lua_pushboolean(L, ret == 1);
+               return 1;
+       }
+       else {
+               return luaL_error(L, "invalid arguments: two kann objects required");
+       }
 }
\ No newline at end of file
index 9994d411a24c395afd9a278edb06d5f3aa47fbf5..e5b3d04763884e120256a89560f6b58eff17db7f 100644 (file)
@@ -21,6 +21,7 @@ local lua_util = require "lua_util"
 local lua_verdict = require "lua_verdict"
 local neural_common = require "plugins/neural"
 local neural_learn = require "lua_neural_learn"
+local neural_external = require "lua_neural_external"
 local rspamd_kann = require "rspamd_kann"
 local rspamd_logger = require "rspamd_logger"
 local rspamd_tensor = require "rspamd_tensor"
@@ -746,6 +747,172 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
   )
 end
 
+--- External model support functions
+
+-- Apply loaded external model to settings element
+-- @param rule neural rule configuration
+-- @param set settings element
+-- @param model parsed external model data
+-- @param ev_base event base (optional, for storing base model)
+local function apply_external_model(rule, set, model, ev_base)
+  local ext_cfg = rule.external_model
+  if not ext_cfg or not model then
+    return false
+  end
+
+  -- Load external ANN
+  local ext_ann, ann_err = neural_external.load_ann(model)
+  if not ext_ann then
+    rspamd_logger.errx(rspamd_config, 'failed to load external ANN for %s:%s: %s',
+      rule.prefix, set.name, ann_err or "unknown")
+    return false
+  end
+
+  -- Check if we have a local ANN to merge with
+  if set.ann and set.ann.ann then
+    -- Check architecture compatibility
+    local ok = ext_ann:is_compatible(set.ann.ann)
+    if not ok then
+      rspamd_logger.warnx(rspamd_config,
+        'external ANN architecture incompatible with local ANN for %s:%s, using external only',
+        rule.prefix, set.name)
+      set.ann.ann = ext_ann
+      set.ann.version = model.model_version or 1
+      set.ann.external_version = model.model_version
+      set.ann.external_source = ext_cfg.url
+      return true
+    end
+
+    -- Merge weights
+    local alpha = ext_cfg.merge_alpha or 0.5
+    local merged, merge_err = ext_ann:merge_weights(set.ann.ann, alpha)
+    if not merged then
+      rspamd_logger.errx(rspamd_config, 'failed to merge ANNs for %s:%s: %s',
+        rule.prefix, set.name, merge_err or "unknown")
+      return false
+    end
+
+    rspamd_logger.infox(rspamd_config,
+      'merged external model (version=%s, alpha=%s) with local ANN for %s:%s',
+      model.model_version, alpha, rule.prefix, set.name)
+
+    -- Update ANN reference
+    set.ann.ann = merged
+    set.ann.version = (set.ann.version or 0) + 1
+    set.ann.external_version = model.model_version
+    set.ann.external_source = ext_cfg.url
+
+    -- Store base model for future re-merge
+    if ev_base then
+      neural_external.store_base_model(rule.redis, ev_base, set.ann.redis_key, model, function(store_err)
+        if store_err then
+          rspamd_logger.warnx(rspamd_config, 'failed to store base model: %s', store_err)
+        end
+      end)
+    end
+  else
+    -- No local ANN, just use external
+    rspamd_logger.infox(rspamd_config,
+      'loaded external model (version=%s) as initial ANN for %s:%s',
+      model.model_version, rule.prefix, set.name)
+
+    set.ann = {
+      version = model.model_version or 1,
+      redis_key = neural_common.new_ann_key(rule, set, model.model_version or 1),
+      external_version = model.model_version,
+      external_source = ext_cfg.url,
+      ann = ext_ann,
+      providers_digest = ext_cfg.providers_digest,
+    }
+
+    -- Store base model for future re-merge
+    if ev_base then
+      neural_external.store_base_model(rule.redis, ev_base, set.ann.redis_key, model, function(store_err)
+        if store_err then
+          rspamd_logger.warnx(rspamd_config, 'failed to store base model: %s', store_err)
+        end
+      end)
+    end
+  end
+
+  -- Load PCA if present
+  local pca = neural_external.load_pca(model)
+  if pca then
+    set.ann.pca = pca
+  end
+
+  -- Copy normalization stats
+  if model.norm_stats then
+    set.ann.norm_stats = model.norm_stats
+  end
+
+  -- Copy ROC thresholds
+  if model.roc_thresholds then
+    set.ann.roc_thresholds = model.roc_thresholds
+  end
+
+  -- Update external model state
+  ext_cfg.last_version = model.model_version
+  ext_cfg.loaded = true
+
+  return true
+end
+
+-- Register external model map for a rule
+-- This should be called at config time
+-- @param rule neural rule configuration
+-- @return boolean success
+local function register_external_model_map(rule)
+  local ext_cfg = rule.external_model
+  if not ext_cfg or not ext_cfg.url then
+    return false
+  end
+
+  -- Store rule reference for callbacks
+  local rule_ref = rule
+
+  -- Map callback: called when external model is loaded/reloaded
+  local function on_model_load(model, err)
+    if err then
+      rspamd_logger.errx(rspamd_config, 'external model load failed for %s: %s',
+        rule_ref.prefix, err)
+      return
+    end
+
+    -- Apply model to all settings
+    for _, set in pairs(rule_ref.settings) do
+      if type(set) == 'table' then
+        apply_external_model(rule_ref, set, model, nil)
+      end
+    end
+  end
+
+  return neural_external.register_model_map(rspamd_config, rule, ext_cfg.providers_digest, on_model_load)
+end
+
+-- Check external model updates (called periodically by map infrastructure)
+-- This is now mostly handled by the map's automatic reload mechanism
+local function check_external_model(worker, cfg, ev_base, rule)
+  local ext_cfg = rule.external_model
+  if not ext_cfg then
+    return
+  end
+
+  -- Check if we have a cached model from the map
+  local cached_model = neural_external.get_cached_model(ext_cfg.url)
+  if cached_model and cached_model.model_version ~= ext_cfg.last_version then
+    rspamd_logger.infox(cfg, 'external model updated for %s: version %s -> %s',
+      rule.prefix, ext_cfg.last_version or 0, cached_model.model_version)
+
+    -- Apply to all settings
+    for _, set in pairs(rule.settings) do
+      if type(set) == 'table' then
+        apply_external_model(rule, set, cached_model, ev_base)
+      end
+    end
+  end
+end
+
 -- Used to check an element in Redis serialized as JSON
 -- for some specific rule + some specific setting
 -- This function tries to load more fresh or more specific ANNs in lieu of
@@ -1192,8 +1359,24 @@ for k, r in pairs(rules) do
     end
   end
 
+  -- External model configuration
+  if rule_elt.external_model then
+    local providers_digest = neural_common.providers_config_digest(rule_elt.providers, rule_elt)
+    rule_elt.external_model = neural_external.create_external_config(rule_elt, providers_digest)
+    if rule_elt.external_model then
+      rspamd_logger.infox(rspamd_config, "configured external model for rule %s: url=%s, merge_alpha=%s",
+        k, rule_elt.external_model.url, rule_elt.external_model.merge_alpha)
+    end
+  end
+
   rspamd_logger.infox(rspamd_config, "register ann rule %s", k)
   settings.rules[k] = rule_elt
+
+  -- Register external model map if configured
+  if rule_elt.external_model then
+    register_external_model_map(rule_elt)
+  end
+
   rspamd_config:set_metric_symbol({
     name = rule_elt.symbol_spam,
     score = 0.0,
@@ -1250,6 +1433,8 @@ for _, rule in pairs(settings.rules) do
         function(_, _)
           -- Clean old ANNs
           cleanup_anns(rule, cfg, ev_base)
+          -- Check for external model updates
+          check_external_model(worker, cfg, ev_base, rule)
           return check_anns(worker, cfg, ev_base, rule, maybe_train_existing_ann,
             'try_train_ann')
         end)