]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Feature] Fasttext embed: SIF word weighting for sentence vectors
authorVsevolod Stakhov <vsevolod@rspamd.com>
Sun, 22 Feb 2026 18:11:44 +0000 (18:11 +0000)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Sun, 22 Feb 2026 18:11:44 +0000 (18:11 +0000)
Add Smooth Inverse Frequency (SIF) weighting to the fasttext embedding
provider. Common words (the, is, a) get near-zero weight while
distinctive words (viagra, invoice) get high weight, significantly
improving embedding quality without changing dimensionality.

Expose get_word_frequency() from the fasttext shim C++ API and Lua
bindings, returning p(word) = count/ntokens from the model vocabulary.

SIF is enabled by default (sif_weight=true, sif_a=1e-3). Combined with
multi-model mean+max pooling, improves F1 from 0.87 to 0.90 in testing.

lualib/plugins/neural/providers/fasttext_embed.lua
src/libserver/fasttext/fasttext_shim.cxx
src/libserver/fasttext/fasttext_shim.h
src/lua/lua_fasttext.cxx

index 860a7c48a8a5947a6d4ce0f6d80e6159ad802666..a38db985425c60dc80ee04da77928e2de5d5d9ba 100644 (file)
@@ -12,6 +12,10 @@ Pooling modes (pooling = "mean_max" by default):
   "mean"     - average of word vectors (classic fasttext sentence vector)
   "mean_max" - concatenation of mean and element-wise max pooling
 
+SIF (Smooth Inverse Frequency) weighting is enabled by default (sif_weight = true).
+Common words (the, is, a) get near-zero weight, distinctive words get high weight.
+Tune with sif_a parameter (default 1e-3). Set sif_weight = false to disable.
+
 Configuration example in neural.conf:
   providers = [
     {
@@ -25,6 +29,8 @@ Configuration example in neural.conf:
       weight = 1.0;
       multi_model = true;    # use all language models (default)
       pooling = "mean_max";  # mean + max pooling (default)
+      sif_weight = true;     # SIF word weighting (default)
+      sif_a = 1e-3;          # SIF smoothing parameter (default)
     }
   ];
 ]] --
@@ -170,11 +176,16 @@ local function extract_words(task, opts)
 end
 
 -- Compute mean and optionally max pooling from word vectors
-local function compute_pooled_vectors(model, words, pooling)
+-- When sif_a > 0, uses SIF (Smooth Inverse Frequency) weighting:
+--   w(word) = a / (a + p(word))
+-- where p(word) is the word probability from the model's vocabulary.
+-- Common words get near-zero weight, distinctive words get high weight.
+local function compute_pooled_vectors(model, words, pooling, sif_a)
   local dim = model:get_dimension()
   local mean_vec = {}
   local max_vec = {}
   local need_max = (pooling == 'mean_max')
+  local use_sif = sif_a and sif_a > 0
 
   for d = 1, dim do
     mean_vec[d] = 0.0
@@ -183,13 +194,22 @@ local function compute_pooled_vectors(model, words, pooling)
     end
   end
 
-  local count = 0
+  local total_weight = 0.0
   for _, w in ipairs(words) do
     local wv = model:get_word_vector(w)
     if wv and #wv >= dim then
-      count = count + 1
+      -- SIF weight: a / (a + p(word)); unknown words get weight 1.0
+      local weight = 1.0
+      if use_sif then
+        local freq = model:get_word_frequency(w)
+        if freq > 0 then
+          weight = sif_a / (sif_a + freq)
+        end
+      end
+
+      total_weight = total_weight + weight
       for d = 1, dim do
-        mean_vec[d] = mean_vec[d] + wv[d]
+        mean_vec[d] = mean_vec[d] + wv[d] * weight
         if need_max and wv[d] > max_vec[d] then
           max_vec[d] = wv[d]
         end
@@ -197,13 +217,13 @@ local function compute_pooled_vectors(model, words, pooling)
     end
   end
 
-  if count == 0 then
+  if total_weight == 0 then
     return nil
   end
 
-  -- Normalize mean
+  -- Normalize weighted mean
   for d = 1, dim do
-    mean_vec[d] = mean_vec[d] / count
+    mean_vec[d] = mean_vec[d] / total_weight
   end
 
   -- L2-normalize mean vector (match fasttext behavior)
@@ -299,12 +319,17 @@ neural_common.register_provider('fasttext_embed', {
     end
 
     local pooling = pcfg.pooling or 'mean_max'
+    -- SIF weighting: enabled by default with a=1e-3
+    local sif_a = pcfg.sif_a
+    if sif_a == nil then
+      sif_a = (pcfg.sif_weight ~= false) and 1e-3 or 0
+    end
     local combined_vec = {}
     local model_names = {}
     local total_dim = 0
 
     for _, m in ipairs(models) do
-      local vec = compute_pooled_vectors(m.model, words, pooling)
+      local vec = compute_pooled_vectors(m.model, words, pooling, sif_a)
       if vec then
         for _, v in ipairs(vec) do
           combined_vec[#combined_vec + 1] = v
index 380085e4276dad6d6ce8186e3ca15556bc3ed01e..b6a7b6da484d7d24672861da1fd09156d8a3439d 100644 (file)
@@ -1127,4 +1127,21 @@ auto fasttext_model::get_ntokens() const -> std::int64_t
        return impl_->dict.get_ntokens();
 }
 
+auto fasttext_model::get_word_frequency(std::string_view word) const -> double
+{
+       auto id = impl_->dict.find(word);
+       if (id < 0) {
+               return 0.0;
+       }
+       auto *entry = impl_->dict.get_entry(id);
+       if (!entry) {
+               return 0.0;
+       }
+       auto ntokens = impl_->dict.get_ntokens();
+       if (ntokens <= 0) {
+               return 0.0;
+       }
+       return static_cast<double>(entry->count) / static_cast<double>(ntokens);
+}
+
 } /* namespace rspamd::fasttext */
index abf47fb61828b76c349f844ae08d4cfce903fdba..45425ba18aa599ffc7dd9e3cd0bf8a8e85d4abeb 100644 (file)
@@ -140,6 +140,12 @@ public:
         */
        auto get_ntokens() const -> std::int64_t;
 
+       /**
+        * Get word probability p(word) = count(word) / ntokens.
+        * Returns 0.0 for unknown words.
+        */
+       auto get_word_frequency(std::string_view word) const -> double;
+
 private:
        explicit fasttext_model(std::unique_ptr<fasttext_model_impl> impl);
        std::unique_ptr<fasttext_model_impl> impl_;
index f44e85ed096c2659ea88b501b004f2f2f070d419..e2f991624d2ba0201cac4ddc16b3a46282ea220d 100644 (file)
@@ -48,6 +48,7 @@ static int lua_fasttext_model_get_dimension(lua_State *L);
 static int lua_fasttext_model_get_sentence_vector(lua_State *L);
 static int lua_fasttext_model_get_word_vector(lua_State *L);
 static int lua_fasttext_model_predict(lua_State *L);
+static int lua_fasttext_model_get_word_frequency(lua_State *L);
 static int lua_fasttext_model_dtor(lua_State *L);
 static int lua_fasttext_model_is_loaded(lua_State *L);
 
@@ -62,6 +63,7 @@ static const struct luaL_reg fasttextlib_m[] = {
        {"get_dimension", lua_fasttext_model_get_dimension},
        {"get_sentence_vector", lua_fasttext_model_get_sentence_vector},
        {"get_word_vector", lua_fasttext_model_get_word_vector},
+       {"get_word_frequency", lua_fasttext_model_get_word_frequency},
        {"predict", lua_fasttext_model_predict},
        {"is_loaded", lua_fasttext_model_is_loaded},
        {"__gc", lua_fasttext_model_dtor},
@@ -155,6 +157,30 @@ lua_fasttext_model_get_dimension(lua_State *L)
        return 1;
 }
 
+/***
+ * @method model:get_word_frequency(word)
+ * Get word probability p(word) = count(word) / total_tokens.
+ * Useful for SIF (Smooth Inverse Frequency) sentence weighting.
+ * @param {string} word input word
+ * @return {number} word probability (0..1), 0 for unknown words
+ */
+static int
+lua_fasttext_model_get_word_frequency(lua_State *L)
+{
+       auto *model = lua_check_fasttext_model(L, 1);
+       const char *word = luaL_checkstring(L, 2);
+
+       if (!model || !model->loaded) {
+               lua_pushnumber(L, 0.0);
+               return 1;
+       }
+
+       auto freq = model->model->get_word_frequency(std::string_view{word});
+       lua_pushnumber(L, freq);
+
+       return 1;
+}
+
 /***
  * @method model:get_word_vector(word)
  * Get embedding vector for a single word