]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Feature] Fasttext embed: multi-scale conv1d pooling for text features 5903/head
authorVsevolod Stakhov <vsevolod@rspamd.com>
Mon, 23 Feb 2026 09:14:35 +0000 (09:14 +0000)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Mon, 23 Feb 2026 09:14:35 +0000 (09:14 +0000)
Add conv1d output mode to the fasttext_embed provider that applies
multi-scale max-over-time pooling over sliding word windows in Lua,
producing compact feature vectors for the neural plugin's dense ANN.

For each kernel size (default {1, 3, 5}), word vectors are averaged
within sliding windows, then max-pooled across positions per channel.
Each scale's features are L2-normalized independently for balanced
contribution. This replaces the previous approach of feeding raw NCW
matrices into KANN conv1d layers.

Also adds max1d and input3d layer bindings to the KANN Lua API, and
includes conv1d settings (kernel_sizes, conv_pooling, max_words) in
the providers_config_digest for automatic retraining on config change.

lualib/plugins/neural.lua
lualib/plugins/neural/providers/fasttext_embed.lua
src/lua/lua_kann.c

index 02bbe1a46d049a1447650739259bdde8ccef057c..7b44d3463f38b4583637a7f0946611e5b495a329 100644 (file)
@@ -266,6 +266,16 @@ local function create_embedding_ann(n, rule)
   return rspamd_kann.new.kann(t)
 end
 
+-- Conv1d ANN: uses the enhanced embedding architecture.
+-- The actual convolution (multi-scale max-over-time pooling) is done in the
+-- fasttext_embed provider, which produces compact feature vectors (n_scales * channels).
+-- The ANN itself is a simple dense network on these pre-convolved features.
+local function create_conv1d_ann(n, rule)
+  lua_util.debugm(N, rspamd_config,
+    'creating conv1d ANN: %s pre-convolved inputs', n)
+  return create_embedding_ann(n, rule)
+end
+
 -- Detects if rule uses LLM embeddings provider
 local function uses_llm_embeddings(rule)
   if not rule.providers then
@@ -281,6 +291,12 @@ end
 
 -- Main ANN factory function - auto-selects architecture based on rule configuration
 local function create_ann(n, nlayers, rule)
+  -- Check for conv1d architecture first
+  if rule.conv1d then
+    lua_util.debugm(N, rspamd_config, 'creating conv1d ANN with %s inputs', n)
+    return create_conv1d_ann(n, rule)
+  end
+
   -- Check if we should use the enhanced embedding architecture
   -- Conditions: has LLM provider, or explicit multi-layer config, or large input dimension
   local use_embedding_arch = uses_llm_embeddings(rule)
@@ -755,6 +771,14 @@ local function providers_config_digest(providers_cfg, rule)
       entry.max_tokens = max_tokens
     end
 
+    -- Conv1d feature extraction settings affect output dimensions
+    if p.output_mode == 'conv1d' then
+      entry.output_mode = 'conv1d'
+      entry.max_words = p.max_words or 32
+      entry.kernel_sizes = p.kernel_sizes or { 1, 3, 5 }
+      entry.conv_pooling = p.conv_pooling or 'max'
+    end
+
     norm.providers[i] = entry
   end
   return lua_util.unordered_table_digest(norm)
index a38db985425c60dc80ee04da77928e2de5d5d9ba..d15d3940a8b646e8e2974510edb949f6cfa9dbd5 100644 (file)
@@ -12,6 +12,16 @@ Pooling modes (pooling = "mean_max" by default):
   "mean"     - average of word vectors (classic fasttext sentence vector)
   "mean_max" - concatenation of mean and element-wise max pooling
 
+Conv1d output mode (output_mode = "conv1d"):
+  Multi-scale max-over-time pooling over sliding word windows.
+  For each kernel size, averages word vectors in a window, then
+  max-pools across all window positions per channel.
+  Produces compact features: n_scales * pools_per_scale * channels.
+  Options:
+    kernel_sizes = [1, 3, 5]  - window sizes (default)
+    conv_pooling = "max"      - per-scale pooling: "max", "mean", "mean_max"
+    max_words = 32            - max word positions (default)
+
 SIF (Smooth Inverse Frequency) weighting is enabled by default (sif_weight = true).
 Common words (the, is, a) get near-zero weight, distinctive words get high weight.
 Tune with sif_a parameter (default 1e-3). Set sif_weight = false to disable.
@@ -260,6 +270,149 @@ local function compute_pooled_vectors(model, words, pooling, sif_a)
   return mean_vec
 end
 
+-- Multi-scale pooling for conv1d feature extraction.
+-- Instead of storing raw NCW matrices and relying on KANN conv1d,
+-- we apply fixed convolution-like operations in Lua and store compact features.
+--
+-- For each window size k in kernel_sizes (default {1, 3, 5}):
+--   1. Slide a window of k words over the sequence
+--   2. Average word vectors within each window (like a fixed conv filter)
+--   3. Max-pool (and optionally mean-pool) over all window positions
+-- Each scale's features are L2-normalized independently for balanced contribution.
+-- Output: flat table of n_scales * n_pool * C floats (e.g., 3 * 2 * 100 = 600 for mean_max)
+local function compute_conv1d_features(models, words, max_words, sif_a, opts)
+  opts = opts or {}
+  local use_sif = sif_a and sif_a > 0
+  local nwords = math.min(#words, max_words)
+  local kernel_sizes = opts.kernel_sizes or { 1, 3, 5 }
+  local conv_pooling = opts.conv_pooling or 'mean_max'
+  local need_mean = (conv_pooling == 'mean_max' or conv_pooling == 'mean')
+  local need_max = (conv_pooling == 'mean_max' or conv_pooling == 'max')
+
+  if nwords == 0 then
+    return nil, 0, 0
+  end
+
+  -- Compute total channels (sum of all model dimensions)
+  local total_channels = 0
+  for _, m in ipairs(models) do
+    total_channels = total_channels + m.model:get_dimension()
+  end
+
+  -- Collect word vectors: word_vecs[w][c] for word position w, channel c
+  local word_vecs = {}
+  for w = 1, nwords do
+    local wv_all = {}
+    for _, m in ipairs(models) do
+      local wv = m.model:get_word_vector(words[w])
+      local dim = m.model:get_dimension()
+      if wv and #wv >= dim then
+        local weight = 1.0
+        if use_sif then
+          local freq = m.model:get_word_frequency(words[w])
+          if freq > 0 then
+            weight = sif_a / (sif_a + freq)
+          end
+        end
+        for d = 1, dim do
+          wv_all[#wv_all + 1] = wv[d] * weight
+        end
+      else
+        for _ = 1, dim do
+          wv_all[#wv_all + 1] = 0.0
+        end
+      end
+    end
+    word_vecs[w] = wv_all
+  end
+
+  -- Multi-scale pooling with per-scale L2 normalization
+  local output = {}
+
+  for _, k in ipairs(kernel_sizes) do
+    local scale_mean = need_mean and {} or nil
+    local scale_max = need_max and {} or nil
+
+    -- Initialize per-scale accumulators
+    for c = 1, total_channels do
+      if scale_mean then
+        scale_mean[c] = 0.0
+      end
+      if scale_max then
+        scale_max[c] = -math.huge
+      end
+    end
+
+    -- Slide window of size k over word positions
+    local n_windows = nwords - k + 1
+    if n_windows < 1 then
+      -- Sequence too short for this kernel; treat each word as a window
+      n_windows = nwords
+      for c = 1, total_channels do
+        for w = 1, nwords do
+          local val = word_vecs[w][c] or 0.0
+          if scale_mean then
+            scale_mean[c] = scale_mean[c] + val
+          end
+          if scale_max and val > scale_max[c] then
+            scale_max[c] = val
+          end
+        end
+        if scale_mean then
+          scale_mean[c] = scale_mean[c] / nwords
+        end
+      end
+    else
+      for c = 1, total_channels do
+        for start = 1, n_windows do
+          -- Average word vectors within this window
+          local sum = 0.0
+          for w = start, start + k - 1 do
+            sum = sum + (word_vecs[w][c] or 0.0)
+          end
+          local avg = sum / k
+          if scale_mean then
+            scale_mean[c] = scale_mean[c] + avg
+          end
+          if scale_max and avg > scale_max[c] then
+            scale_max[c] = avg
+          end
+        end
+        if scale_mean then
+          scale_mean[c] = scale_mean[c] / n_windows
+        end
+      end
+    end
+
+    -- L2-normalize and append mean features for this scale
+    if scale_mean then
+      local norm = 0.0
+      for c = 1, total_channels do
+        norm = norm + scale_mean[c] * scale_mean[c]
+      end
+      norm = math.sqrt(norm)
+      for c = 1, total_channels do
+        output[#output + 1] = norm > 0 and (scale_mean[c] / norm) or 0.0
+      end
+    end
+
+    -- L2-normalize and append max features for this scale
+    if scale_max then
+      local norm = 0.0
+      for c = 1, total_channels do
+        norm = norm + scale_max[c] * scale_max[c]
+      end
+      norm = math.sqrt(norm)
+      for c = 1, total_channels do
+        output[#output + 1] = norm > 0 and (scale_max[c] / norm) or 0.0
+      end
+    end
+  end
+
+  local pools_per_scale = (need_mean and 1 or 0) + (need_max and 1 or 0)
+  return output, total_channels, pools_per_scale
+end
+
 neural_common.register_provider('fasttext_embed', {
   collect_async = function(task, ctx, cont)
     local pcfg = ctx.config or {}
@@ -318,6 +471,54 @@ neural_common.register_provider('fasttext_embed', {
       return
     end
 
+    -- Conv1d output mode: multi-scale max-over-time pooling.
+    -- For each kernel size, averages word vectors in sliding windows, then
+    -- max-pools across all positions per channel. Produces compact features:
+    -- n_scales * pools_per_scale * total_channels.
+    if pcfg.output_mode == 'conv1d' then
+      local max_words = pcfg.max_words or 32
+      local sif_a = pcfg.sif_a
+      if sif_a == nil then
+        sif_a = (pcfg.sif_weight ~= false) and 1e-3 or 0
+      end
+
+      local kernel_sizes = pcfg.kernel_sizes or { 1, 3, 5 }
+      local conv_pooling = pcfg.conv_pooling or 'max'
+      local model_names = {}
+      for _, m in ipairs(models) do
+        model_names[#model_names + 1] = m.lang
+      end
+
+      local combined_vec, total_channels, pools_per_scale = compute_conv1d_features(
+        models, words, max_words, sif_a,
+        { kernel_sizes = kernel_sizes, conv_pooling = conv_pooling })
+
+      if not combined_vec or #combined_vec == 0 then
+        rspamd_logger.debugm(N, task, 'fasttext_embed: conv1d produced empty features; skip')
+        cont(nil)
+        return
+      end
+
+      local meta = {
+        name = pcfg.name or 'fasttext_embed',
+        type = 'fasttext_embed',
+        output_mode = 'conv1d',
+        channels = total_channels,
+        n_scales = #kernel_sizes,
+        pools_per_scale = pools_per_scale,
+        dim = #combined_vec,
+        weight = ctx.weight or 1.0,
+        models = table.concat(model_names, '+'),
+      }
+
+      rspamd_logger.debugm(N, task,
+        'fasttext_embed: conv1d k=%s pool=%s dim=%s (%s models, %s words)',
+        table.concat(kernel_sizes, ','), conv_pooling,
+        #combined_vec, #models, math.min(#words, max_words))
+      cont(combined_vec, meta)
+      return
+    end
+
     local pooling = pcfg.pooling or 'mean_max'
     -- SIF weighting: enabled by default with a=1e-3
     local sif_a = pcfg.sif_a
index 9772aecd39164cc91b6e24db3cd4192a31ebdb55..e9b91e7ee62168848ae4a33cf0f748846ae1ec21 100644 (file)
@@ -66,6 +66,8 @@ KANN_LAYER_DEF(lstm);
 KANN_LAYER_DEF(gru);
 KANN_LAYER_DEF(conv2d);
 KANN_LAYER_DEF(conv1d);
+KANN_LAYER_DEF(max1d);
+KANN_LAYER_DEF(input3d);
 KANN_LAYER_DEF(cost);
 
 static int lua_kann_layer_layerdropout(lua_State *L); /* forward declaration */
@@ -80,6 +82,8 @@ static luaL_reg rspamd_kann_layers_f[] = {
        KANN_LAYER_INTERFACE(gru),
        KANN_LAYER_INTERFACE(conv2d),
        KANN_LAYER_INTERFACE(conv1d),
+       KANN_LAYER_INTERFACE(max1d),
+       KANN_LAYER_INTERFACE(input3d),
        KANN_LAYER_INTERFACE(cost),
        {NULL, NULL},
 };
@@ -613,9 +617,75 @@ lua_kann_layer_conv1d(lua_State *L)
        return 1;
 }
 
+/***
+ * @function kann.layer.max1d(in, kern_size, stride_size, pad_size[, flags])
+ * Creates 1D max pooling layer (for use after conv1d)
+ * @param {kann_node} in kann node (must be 3D: NCW)
+ * @param {int} kern_size kernel size (use width for global pooling)
+ * @param {int} stride_size stride
+ * @param {int} pad_size padding
+ * @param {table|int} flags optional flags
+ * @return {kann_node} kann node object (should be used to combine ANN)
+*/
+static int
+lua_kann_layer_max1d(lua_State *L)
+{
+       kad_node_t *in = lua_check_kann_node(L, 1);
+       int k_size = luaL_checkinteger(L, 2);
+       int stride = luaL_checkinteger(L, 3);
+       int pad = luaL_checkinteger(L, 4);
+
+       if (in != NULL) {
+               kad_node_t *t;
+               t = kad_max1d(in, k_size, stride, pad);
+
+               if (t == NULL) {
+                       return luaL_error(L, "max1d requires 3D (NCW) input");
+               }
+
+               PROCESS_KAD_FLAGS(t, 5);
+               PUSH_KAD_NODE(t);
+       }
+       else {
+               return luaL_error(L, "invalid arguments, input, k, stride, pad required");
+       }
+
+       return 1;
+}
+
+/***
+ * @function kann.layer.input3d(channels, width[, flags])
+ * Creates a 3D input layer in NCW format (for conv1d networks)
+ * @param {int} channels number of channels (e.g. embedding dimension)
+ * @param {int} width sequence length (e.g. max words)
+ * @param {table|int} flags optional flags
+ * @return {kann_node} kann node object (should be used to combine ANN)
+*/
+static int
+lua_kann_layer_input3d(lua_State *L)
+{
+       int channels = luaL_checkinteger(L, 1);
+       int width = luaL_checkinteger(L, 2);
+
+       if (channels > 0 && width > 0) {
+               kad_node_t *t;
+
+               t = kad_feed(3, 1, channels, width);
+               t->ext_flag |= KANN_F_IN;
+
+               PROCESS_KAD_FLAGS(t, 3);
+               PUSH_KAD_NODE(t);
+       }
+       else {
+               return luaL_error(L, "invalid arguments, channels and width required");
+       }
+
+       return 1;
+}
+
 /***
  * @function kann.layer.cost(in, nout, cost_type[, flags])
- * Creates 1D convolution layer
+ * Creates cost layer
  * @param {kann_node} in kann node
  * @param {int} nout number of outputs
  * @param {int} cost_type see kann.cost table