]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Feature] Wire Lua rspamd_fasttext through maps infrastructure 5909/head
authorVsevolod Stakhov <vsevolod@rspamd.com>
Thu, 26 Feb 2026 12:54:05 +0000 (12:54 +0000)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Thu, 26 Feb 2026 12:54:05 +0000 (12:54 +0000)
Add load_map(cfg, path) to rspamd_fasttext module that loads FastText
models via the maps infrastructure (HTTP URLs + file with shared mmap).
The fasttext_embed neural provider now registers models as maps at
config time via a new init callback, enabling shared memory across
workers and automatic reload on map updates.

lualib/plugins/neural.lua
lualib/plugins/neural/providers/fasttext_embed.lua
src/lua/lua_fasttext.cxx
src/plugins/lua/neural.lua

index ab8c49652c0feae5d193a156c0e35331aeaa4171..5d4ee4e2de9bca882b5497b9febba34425253e2c 100644 (file)
@@ -1490,6 +1490,7 @@ return {
   build_providers_meta = build_providers_meta,
   apply_normalization = apply_normalization,
   gen_unlock_cb = gen_unlock_cb,
+  get_provider = get_provider,
   get_rule_settings = get_rule_settings,
   load_scripts = load_scripts,
   module_config = module_config,
index d15d3940a8b646e8e2974510edb949f6cfa9dbd5..b56d270f98fd209f9f51d8ffbf141dd501ed3edb 100644 (file)
@@ -67,21 +67,28 @@ local function load_model(path)
   end
 
   if loaded_models[path] then
-    return loaded_models[path]
+    local cached = loaded_models[path]
+    if cached:is_loaded() then
+      return cached
+    end
+    -- Cached but map not ready yet
+    return nil
   end
 
+  -- Use load_map for map-backed loading (supports HTTP URLs + file reload)
   rspamd_logger.infox(rspamd_config, '%s: loading FastText model from %s', N, path)
-  local model = rspamd_fasttext.load(path)
+  local model = rspamd_fasttext.load_map(rspamd_config, path)
 
-  if model and model:is_loaded() then
-    rspamd_logger.infox(rspamd_config, '%s: loaded FastText model %s, dimension=%s',
-      N, path, model:get_dimension())
+  if model then
     loaded_models[path] = model
-    return model
-  else
-    rspamd_logger.errx(rspamd_config, '%s: failed to load FastText model from %s', N, path)
-    return nil
+    if model:is_loaded() then
+      rspamd_logger.infox(rspamd_config, '%s: loaded FastText model %s, dimension=%s',
+        N, path, model:get_dimension())
+      return model
+    end
   end
+
+  return nil
 end
 
 -- Collect all available models (for multi_model mode)
@@ -414,6 +421,17 @@ local function compute_conv1d_features(models, words, max_words, sif_a, opts)
 end
 
 neural_common.register_provider('fasttext_embed', {
+  init = function(pcfg)
+    -- Pre-register map-backed models at config time
+    if pcfg.model then
+      load_model(pcfg.model)
+    end
+    if pcfg.language_models then
+      for _, path in pairs(pcfg.language_models) do
+        load_model(path)
+      end
+    end
+  end,
   collect_async = function(task, ctx, cont)
     local pcfg = ctx.config or {}
 
index e2f991624d2ba0201cac4ddc16b3a46282ea220d..27264366775c8fbb195ca8af737337c4b423e0b7 100644 (file)
@@ -35,6 +35,9 @@
  */
 
 #include "fasttext_shim.h"
+#include "libserver/cfg_file.h"
+#include "libserver/maps/map.h"
+#include "libserver/maps/map_private.h"
 #include <string>
 #include <vector>
 #include <cmath>
@@ -44,6 +47,7 @@
 
 /* Forward declarations */
 static int lua_fasttext_load(lua_State *L);
+static int lua_fasttext_load_map(lua_State *L);
 static int lua_fasttext_model_get_dimension(lua_State *L);
 static int lua_fasttext_model_get_sentence_vector(lua_State *L);
 static int lua_fasttext_model_get_word_vector(lua_State *L);
@@ -55,6 +59,7 @@ static int lua_fasttext_model_is_loaded(lua_State *L);
 /* Module functions */
 static const struct luaL_reg fasttextlib_f[] = {
        {"load", lua_fasttext_load},
+       {"load_map", lua_fasttext_load_map},
        {nullptr, nullptr},
 };
 
@@ -71,11 +76,32 @@ static const struct luaL_reg fasttextlib_m[] = {
        {nullptr, nullptr},
 };
 
+/**
+ * Map callback data for fasttext model loading via maps infrastructure.
+ * Same pattern as lang_detection_fasttext.cxx.
+ */
+struct fasttext_map_data {
+       rspamd::fasttext::fasttext_model *model = nullptr;
+};
+
 struct rspamd_lua_fasttext_model {
-       rspamd::fasttext::fasttext_model *model;
-       bool loaded;
+       rspamd::fasttext::fasttext_model *owned_model; /* non-null for direct load */
+       void **map_target;                             /* non-null for map-backed load */
 };
 
+static rspamd::fasttext::fasttext_model *
+lua_fasttext_get_model(struct rspamd_lua_fasttext_model *wrap)
+{
+       if (wrap->owned_model) {
+               return wrap->owned_model;
+       }
+       if (wrap->map_target && *wrap->map_target) {
+               auto *fdata = static_cast<fasttext_map_data *>(*wrap->map_target);
+               return fdata->model;
+       }
+       return nullptr;
+}
+
 static struct rspamd_lua_fasttext_model *
 lua_check_fasttext_model(lua_State *L, int pos)
 {
@@ -85,9 +111,80 @@ lua_check_fasttext_model(lua_State *L, int pos)
        return *pmodel;
 }
 
+/* Map read callback: receives filename, loads model */
+static char *
+lua_fasttext_map_read_cb(char *chunk, int len,
+                                                struct map_cb_data *data, gboolean final)
+{
+       if (data->cur_data == nullptr) {
+               data->cur_data = new fasttext_map_data();
+       }
+
+       if (!final) {
+               return chunk + len;
+       }
+
+       auto *fdata = static_cast<fasttext_map_data *>(data->cur_data);
+       auto *map = data->map;
+       auto fname = std::string{chunk, static_cast<std::size_t>(len)};
+       auto offset = static_cast<std::int64_t>(
+               rspamd_map_get_no_file_read_offset(data->map));
+
+       auto result = rspamd::fasttext::fasttext_model::load(fname, offset);
+       if (result) {
+               fdata->model = new rspamd::fasttext::fasttext_model(std::move(*result));
+               msg_info_map("loaded fasttext model from %s (offset %z)",
+                                        fname.c_str(), (gsize) offset);
+       }
+       else {
+               msg_err_map("cannot load fasttext model from %s (offset %z): %s",
+                                       fname.c_str(), (gsize) offset,
+                                       result.error().error_message.data());
+       }
+
+       return chunk + len;
+}
+
+/* Map fin callback: swap old model for new one */
+static void
+lua_fasttext_map_fin_cb(struct map_cb_data *data, void **target)
+{
+       auto *new_data = static_cast<fasttext_map_data *>(data->cur_data);
+       auto *old_data = static_cast<fasttext_map_data *>(data->prev_data);
+
+       if (data->errored) {
+               if (new_data) {
+                       delete new_data->model;
+                       delete new_data;
+                       data->cur_data = nullptr;
+               }
+               return;
+       }
+
+       if (target) {
+               *target = data->cur_data;
+       }
+
+       if (old_data) {
+               delete old_data->model;
+               delete old_data;
+       }
+}
+
+/* Map destructor callback */
+static void
+lua_fasttext_map_dtor_cb(struct map_cb_data *data)
+{
+       auto *fdata = static_cast<fasttext_map_data *>(data->cur_data);
+       if (fdata) {
+               delete fdata->model;
+               delete fdata;
+       }
+}
+
 /***
  * @function rspamd_fasttext.load(path)
- * Load a FastText model from file
+ * Load a FastText model from file (direct synchronous load)
  * @param {string} path path to the .bin model file
  * @return {rspamd_fasttext} model object (check is_loaded())
  */
@@ -97,8 +194,8 @@ lua_fasttext_load(lua_State *L)
        const char *path = luaL_checkstring(L, 1);
 
        auto *model = new rspamd_lua_fasttext_model();
-       model->model = nullptr;
-       model->loaded = false;
+       model->owned_model = nullptr;
+       model->map_target = nullptr;
 
        /* Store pointer in userdata */
        auto **pmodel = static_cast<struct rspamd_lua_fasttext_model **>(
@@ -114,8 +211,7 @@ lua_fasttext_load(lua_State *L)
 
        auto result = rspamd::fasttext::fasttext_model::load(path);
        if (result) {
-               model->model = new rspamd::fasttext::fasttext_model(std::move(*result));
-               model->loaded = true;
+               model->owned_model = new rspamd::fasttext::fasttext_model(std::move(*result));
        }
        else {
                msg_err("fasttext model '%s' failed to load: %s", path,
@@ -125,6 +221,73 @@ lua_fasttext_load(lua_State *L)
        return 1;
 }
 
+/***
+ * @function rspamd_fasttext.load_map(cfg, path)
+ * Load a FastText model via maps infrastructure (supports HTTP URLs + file reload).
+ * For plain file paths, falls back to direct loading.
+ * Must be called at config time (not during task processing).
+ * @param {rspamd_config} cfg configuration object
+ * @param {string} path path or map URL to the .bin model file
+ * @return {rspamd_fasttext} model object (check is_loaded())
+ */
+static int
+lua_fasttext_load_map(lua_State *L)
+{
+       auto *cfg = lua_check_config(L, 1);
+       const char *path = luaL_checkstring(L, 2);
+
+       if (!cfg) {
+               return luaL_error(L, "invalid config argument");
+       }
+
+       auto *model = new rspamd_lua_fasttext_model();
+       model->owned_model = nullptr;
+       model->map_target = nullptr;
+
+       /* Store pointer in userdata */
+       auto **pmodel = static_cast<struct rspamd_lua_fasttext_model **>(
+               lua_newuserdata(L, sizeof(struct rspamd_lua_fasttext_model *)));
+       *pmodel = model;
+       rspamd_lua_setclass(L, FASTTEXT_MODEL_CLASS, -1);
+
+       if (rspamd_map_is_map(path)) {
+               /* Map-backed load: allocate target on config mempool */
+               model->map_target = static_cast<void **>(
+                       rspamd_mempool_alloc0(cfg->cfg_pool, sizeof(void *)));
+
+               auto *map = rspamd_map_add(cfg, path,
+                                                                  "fasttext model",
+                                                                  lua_fasttext_map_read_cb,
+                                                                  lua_fasttext_map_fin_cb,
+                                                                  lua_fasttext_map_dtor_cb,
+                                                                  model->map_target,
+                                                                  nullptr,
+                                                                  RSPAMD_MAP_FILE_NO_READ);
+
+               if (!map) {
+                       msg_err_config("cannot add map for fasttext model '%s'", path);
+               }
+       }
+       else {
+               /* Direct file load (same as load()) */
+               if (access(path, R_OK) != 0) {
+                       msg_err_config("fasttext model '%s' is not readable: %s", path, strerror(errno));
+                       return 1;
+               }
+
+               auto result = rspamd::fasttext::fasttext_model::load(path);
+               if (result) {
+                       model->owned_model = new rspamd::fasttext::fasttext_model(std::move(*result));
+               }
+               else {
+                       msg_err_config("fasttext model '%s' failed to load: %s", path,
+                                                  result.error().error_message.data());
+               }
+       }
+
+       return 1;
+}
+
 /***
  * @method model:is_loaded()
  * Check if the model was loaded successfully
@@ -134,7 +297,7 @@ static int
 lua_fasttext_model_is_loaded(lua_State *L)
 {
        auto *model = lua_check_fasttext_model(L, 1);
-       lua_pushboolean(L, model && model->loaded);
+       lua_pushboolean(L, model && lua_fasttext_get_model(model) != nullptr);
        return 1;
 }
 
@@ -146,14 +309,15 @@ lua_fasttext_model_is_loaded(lua_State *L)
 static int
 lua_fasttext_model_get_dimension(lua_State *L)
 {
-       auto *model = lua_check_fasttext_model(L, 1);
+       auto *wrap = lua_check_fasttext_model(L, 1);
+       auto *ft = wrap ? lua_fasttext_get_model(wrap) : nullptr;
 
-       if (!model || !model->loaded) {
+       if (!ft) {
                lua_pushinteger(L, 0);
                return 1;
        }
 
-       lua_pushinteger(L, model->model->get_dimension());
+       lua_pushinteger(L, ft->get_dimension());
        return 1;
 }
 
@@ -167,15 +331,16 @@ lua_fasttext_model_get_dimension(lua_State *L)
 static int
 lua_fasttext_model_get_word_frequency(lua_State *L)
 {
-       auto *model = lua_check_fasttext_model(L, 1);
+       auto *wrap = lua_check_fasttext_model(L, 1);
        const char *word = luaL_checkstring(L, 2);
+       auto *ft = wrap ? lua_fasttext_get_model(wrap) : nullptr;
 
-       if (!model || !model->loaded) {
+       if (!ft) {
                lua_pushnumber(L, 0.0);
                return 1;
        }
 
-       auto freq = model->model->get_word_frequency(std::string_view{word});
+       auto freq = ft->get_word_frequency(std::string_view{word});
        lua_pushnumber(L, freq);
 
        return 1;
@@ -190,17 +355,18 @@ lua_fasttext_model_get_word_frequency(lua_State *L)
 static int
 lua_fasttext_model_get_word_vector(lua_State *L)
 {
-       auto *model = lua_check_fasttext_model(L, 1);
+       auto *wrap = lua_check_fasttext_model(L, 1);
        const char *word = luaL_checkstring(L, 2);
+       auto *ft = wrap ? lua_fasttext_get_model(wrap) : nullptr;
 
-       if (!model || !model->loaded) {
+       if (!ft) {
                lua_pushnil(L);
                return 1;
        }
 
        std::vector<float> vec;
 
-       model->model->get_word_vector(vec, std::string_view{word});
+       ft->get_word_vector(vec, std::string_view{word});
 
        auto vec_size = static_cast<std::int32_t>(vec.size());
        lua_createtable(L, vec_size, 0);
@@ -223,16 +389,17 @@ lua_fasttext_model_get_word_vector(lua_State *L)
 static int
 lua_fasttext_model_get_sentence_vector(lua_State *L)
 {
-       auto *model = lua_check_fasttext_model(L, 1);
+       auto *wrap = lua_check_fasttext_model(L, 1);
+       auto *ft = wrap ? lua_fasttext_get_model(wrap) : nullptr;
 
-       if (!model || !model->loaded) {
+       if (!ft) {
                lua_pushnil(L);
                return 1;
        }
 
        luaL_argcheck(L, lua_istable(L, 2), 2, "'table' of words expected");
 
-       auto dim = model->model->get_dimension();
+       auto dim = ft->get_dimension();
        if (dim <= 0 || dim > 4096) {
                lua_pushnil(L);
                return 1;
@@ -251,7 +418,7 @@ lua_fasttext_model_get_sentence_vector(lua_State *L)
                        std::size_t len;
                        const char *w = lua_tolstring(L, -1, &len);
                        if (len > 0) {
-                               model->model->get_word_vector(word_vec, std::string_view{w, len});
+                               ft->get_word_vector(word_vec, std::string_view{w, len});
                                auto wv_size = std::min(dim, static_cast<std::int32_t>(word_vec.size()));
                                for (std::int32_t d = 0; d < wv_size; d++) {
                                        sentence_vec[d] += word_vec[d];
@@ -307,9 +474,10 @@ lua_fasttext_model_get_sentence_vector(lua_State *L)
 static int
 lua_fasttext_model_predict(lua_State *L)
 {
-       auto *model = lua_check_fasttext_model(L, 1);
+       auto *wrap = lua_check_fasttext_model(L, 1);
+       auto *ft = wrap ? lua_fasttext_get_model(wrap) : nullptr;
 
-       if (!model || !model->loaded) {
+       if (!ft) {
                lua_pushnil(L);
                return 1;
        }
@@ -327,7 +495,7 @@ lua_fasttext_model_predict(lua_State *L)
                        std::size_t len;
                        const char *w = lua_tolstring(L, -1, &len);
                        if (len > 0) {
-                               model->model->word2vec(std::string_view{w, len}, word_ids);
+                               ft->word2vec(std::string_view{w, len}, word_ids);
                        }
                }
                lua_pop(L, 1);
@@ -339,7 +507,7 @@ lua_fasttext_model_predict(lua_State *L)
        }
 
        std::vector<rspamd::fasttext::prediction> preds;
-       model->model->predict(k, word_ids, preds, 0.0f);
+       ft->predict(k, word_ids, preds, 0.0f);
 
        lua_createtable(L, static_cast<int>(preds.size()), 0);
        for (std::size_t i = 0; i < preds.size(); i++) {
@@ -361,7 +529,9 @@ lua_fasttext_model_dtor(lua_State *L)
                rspamd_lua_check_udata(L, 1, FASTTEXT_MODEL_CLASS));
 
        if (pmodel && *pmodel) {
-               delete (*pmodel)->model;
+               /* Only delete owned_model; map_target is mempool-allocated and
+                * the map infrastructure owns the model data */
+               delete (*pmodel)->owned_model;
                delete *pmodel;
                *pmodel = nullptr;
        }
index 7166140a659f21292d1bef9f4aabe912bcb42632..9994d411a24c395afd9a278edb06d5f3aa47fbf5 100644 (file)
@@ -1174,7 +1174,7 @@ for k, r in pairs(rules) do
     rule_elt.max_inputs = nil
   end
 
-  -- Phase 4: basic provider config validation
+  -- Phase 4: basic provider config validation + init
   if rule_elt.providers and #rule_elt.providers > 0 then
     for i, pcfg in ipairs(rule_elt.providers) do
       if not (pcfg.type or pcfg.name) then
@@ -1184,6 +1184,11 @@ for k, r in pairs(rules) do
         rspamd_logger.errx(rspamd_config,
           'llm provider in rule %s requires model; please set providers[i].model or gpt.model', k)
       end
+      -- Call provider init at config time (for map registration etc.)
+      local prov = neural_common.get_provider(pcfg.type or pcfg.name)
+      if prov and prov.init then
+        prov.init(pcfg)
+      end
     end
   end