]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Feature] Fasttext embed: multi-model and mean+max pooling 5897/head
authorVsevolod Stakhov <vsevolod@rspamd.com>
Sun, 22 Feb 2026 11:07:49 +0000 (11:07 +0000)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Sun, 22 Feb 2026 11:07:49 +0000 (11:07 +0000)
Use all configured language_models for every message by default
(multi_model=true), concatenating vectors from each model for
richer cross-lingual representations.

Add mean+max pooling (pooling="mean_max" default) which concatenates
the average word vector with element-wise max pooling, capturing both
typical and prominent semantic features.

With 2 quantized 50-dim models this produces 200-dim vectors instead
of 50, significantly improving classification (F1 0.51 -> 0.87 in
testing).

lualib/plugins/neural/providers/fasttext_embed.lua

index 6627c8739cd24e52c37d1163c5f04c04a65f7f60..860a7c48a8a5947a6d4ce0f6d80e6159ad802666 100644 (file)
@@ -4,6 +4,14 @@ Loads a FastText model (supervised or unsupervised) and computes sentence
 embeddings from message text. Supports per-language models for multilingual
 deployments.
 
+By default, all configured language_models are used for every message
+(multi_model = true), producing richer cross-lingual representations.
+Set multi_model = false to select a single model based on detected language.
+
+Pooling modes (pooling = "mean_max" by default):
+  "mean"     - average of word vectors (classic fasttext sentence vector)
+  "mean_max" - concatenation of mean and element-wise max pooling
+
 Configuration example in neural.conf:
   providers = [
     {
@@ -15,6 +23,8 @@ Configuration example in neural.conf:
         ru = "/path/to/ru_model.bin";
       };
       weight = 1.0;
+      multi_model = true;    # use all language models (default)
+      pooling = "mean_max";  # mean + max pooling (default)
     }
   ];
 ]] --
@@ -58,19 +68,48 @@ local function load_model(path)
   end
 end
 
--- Detect primary language from the displayed text part
-local function detect_language(task)
-  local part = lua_mime.get_displayed_text_part(task)
-  if part then
-    local lang = part:get_language()
-    if lang and lang ~= '' then
-      return lang
+-- Collect all available models (for multi_model mode)
+local function collect_all_models(pcfg)
+  local models = {}
+
+  if pcfg.language_models then
+    -- Sort by language key for deterministic order
+    local langs = {}
+    for lang, _ in pairs(pcfg.language_models) do
+      langs[#langs + 1] = lang
+    end
+    table.sort(langs)
+
+    for _, lang in ipairs(langs) do
+      local path = pcfg.language_models[lang]
+      local model = load_model(path)
+      if model then
+        models[#models + 1] = { model = model, path = path, lang = lang }
+      end
     end
   end
-  return nil
+
+  -- Add default model if configured and not already loaded
+  if pcfg.model then
+    local already_loaded = false
+    for _, m in ipairs(models) do
+      if m.path == pcfg.model then
+        already_loaded = true
+        break
+      end
+    end
+    if not already_loaded then
+      local model = load_model(pcfg.model)
+      if model then
+        models[#models + 1] = { model = model, path = pcfg.model, lang = 'default' }
+      end
+    end
+  end
+
+  return models
 end
 
--- Select the appropriate model based on language
+-- Select a single model based on language (for single-model mode)
 local function select_model(pcfg, language)
   -- Check per-language models first
   if language and pcfg.language_models then
@@ -78,7 +117,7 @@ local function select_model(pcfg, language)
     if lang_path then
       local model = load_model(lang_path)
       if model then
-        return model, lang_path
+        return { { model = model, path = lang_path, lang = language } }
       end
     end
   end
@@ -87,11 +126,11 @@ local function select_model(pcfg, language)
   if pcfg.model then
     local model = load_model(pcfg.model)
     if model then
-      return model, pcfg.model
+      return { { model = model, path = pcfg.model, lang = 'default' } }
     end
   end
 
-  return nil, nil
+  return {}
 end
 
 -- Extract words from text parts
@@ -130,6 +169,77 @@ local function extract_words(task, opts)
   return words
 end
 
+-- Compute mean and optionally max pooling from word vectors
+local function compute_pooled_vectors(model, words, pooling)
+  local dim = model:get_dimension()
+  local mean_vec = {}
+  local max_vec = {}
+  local need_max = (pooling == 'mean_max')
+
+  for d = 1, dim do
+    mean_vec[d] = 0.0
+    if need_max then
+      max_vec[d] = -math.huge
+    end
+  end
+
+  local count = 0
+  for _, w in ipairs(words) do
+    local wv = model:get_word_vector(w)
+    if wv and #wv >= dim then
+      count = count + 1
+      for d = 1, dim do
+        mean_vec[d] = mean_vec[d] + wv[d]
+        if need_max and wv[d] > max_vec[d] then
+          max_vec[d] = wv[d]
+        end
+      end
+    end
+  end
+
+  if count == 0 then
+    return nil
+  end
+
+  -- Normalize mean
+  for d = 1, dim do
+    mean_vec[d] = mean_vec[d] / count
+  end
+
+  -- L2-normalize mean vector (match fasttext behavior)
+  local norm = 0.0
+  for d = 1, dim do
+    norm = norm + mean_vec[d] * mean_vec[d]
+  end
+  norm = math.sqrt(norm)
+  if norm > 0 then
+    for d = 1, dim do
+      mean_vec[d] = mean_vec[d] / norm
+    end
+  end
+
+  if need_max then
+    -- L2-normalize max vector
+    norm = 0.0
+    for d = 1, dim do
+      norm = norm + max_vec[d] * max_vec[d]
+    end
+    norm = math.sqrt(norm)
+    if norm > 0 then
+      for d = 1, dim do
+        max_vec[d] = max_vec[d] / norm
+      end
+    end
+
+    -- Concatenate mean + max
+    for d = 1, dim do
+      mean_vec[dim + d] = max_vec[d]
+    end
+  end
+
+  return mean_vec
+end
+
 neural_common.register_provider('fasttext_embed', {
   collect_async = function(task, ctx, cont)
     local pcfg = ctx.config or {}
@@ -140,11 +250,27 @@ neural_common.register_provider('fasttext_embed', {
       return
     end
 
-    local language = detect_language(task)
-    local model, model_path = select_model(pcfg, language)
+    -- Select models: all models or single based on language
+    local multi_model = pcfg.multi_model ~= false -- default true
+    local models
+    if multi_model and pcfg.language_models then
+      models = collect_all_models(pcfg)
+    else
+      local language = task and (function()
+        local part = lua_mime.get_displayed_text_part(task)
+        if part then
+          local lang = part:get_language()
+          if lang and lang ~= '' then
+            return lang
+          end
+        end
+        return nil
+      end)()
+      models = select_model(pcfg, language)
+    end
 
-    if not model then
-      rspamd_logger.debugm(N, task, 'fasttext_embed: no model available; skip')
+    if #models == 0 then
+      rspamd_logger.debugm(N, task, 'fasttext_embed: no models available; skip')
       cont(nil)
       return
     end
@@ -172,14 +298,24 @@ neural_common.register_provider('fasttext_embed', {
       return
     end
 
-    local dim = model:get_dimension()
-    rspamd_logger.debugm(N, task, 'fasttext_embed: computing %s-dim vector from %s words (lang=%s, model=%s)',
-      dim, #words, language or 'unknown', model_path)
+    local pooling = pcfg.pooling or 'mean_max'
+    local combined_vec = {}
+    local model_names = {}
+    local total_dim = 0
 
-    local vec = model:get_sentence_vector(words)
+    for _, m in ipairs(models) do
+      local vec = compute_pooled_vectors(m.model, words, pooling)
+      if vec then
+        for _, v in ipairs(vec) do
+          combined_vec[#combined_vec + 1] = v
+        end
+        total_dim = total_dim + #vec
+        model_names[#model_names + 1] = m.lang
+      end
+    end
 
-    if not vec or #vec == 0 then
-      rspamd_logger.debugm(N, task, 'fasttext_embed: empty vector; skip')
+    if #combined_vec == 0 then
+      rspamd_logger.debugm(N, task, 'fasttext_embed: empty vectors from all models; skip')
       cont(nil)
       return
     end
@@ -187,13 +323,14 @@ neural_common.register_provider('fasttext_embed', {
     local meta = {
       name = pcfg.name or 'fasttext_embed',
       type = 'fasttext_embed',
-      dim = dim,
+      dim = total_dim,
       weight = ctx.weight or 1.0,
-      model = model_path,
-      language = language,
+      models = table.concat(model_names, '+'),
+      pooling = pooling,
     }
 
-    rspamd_logger.debugm(N, task, 'fasttext_embed: produced %s-dim vector', #vec)
-    cont(vec, meta)
+    rspamd_logger.debugm(N, task, 'fasttext_embed: produced %s-dim vector (%s models, %s pooling, %s words)',
+      total_dim, #models, pooling, #words)
+    cont(combined_vec, meta)
   end,
 })