[Fix] Fasttext shim: fix quantized model support and add hierarchical softmax
author     Vsevolod Stakhov <vsevolod@rspamd.com>  Sat, 21 Feb 2026 18:54:10 +0000
committer  Vsevolod Stakhov <vsevolod@rspamd.com>  Sat, 21 Feb 2026 18:54:10 +0000
- Add pruneidx support in dictionary subword computation: pruned models
  (like lid.176.ftz) remap subword hash IDs through the prune index table,
  matching FastText's pushHash() behavior
- Apply per-row norm from quantized norm PQ (NPQ) when decoding QMatrix
  rows, fixing vector magnitude for quantized models
- Implement hierarchical softmax prediction via Huffman tree DFS,
  matching FastText's HierarchicalSoftmaxLoss for models with loss=hs
- Add predict() method to Lua rspamd_fasttext bindings (usage sketch below)
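
A minimal usage sketch of the new Lua binding. The require call and the model
path are illustrative assumptions (following the pattern of other rspamd_*
classes), not part of this commit:

    local rspamd_fasttext = require "rspamd_fasttext"
    -- hypothetical path to a quantized (.ftz) language-id model
    local model = rspamd_fasttext.load("/var/lib/rspamd/lid.176.ftz")

    if model and model:is_loaded() then
      -- top-3 predictions for a tokenized input; predict() returns an array
      -- of {label=string, prob=number}, sorted by probability descending
      local preds = model:predict({ "bonjour", "tout", "le", "monde" }, 3)
      for _, p in ipairs(preds) do
        print(p.label, p.prob) -- e.g. "__label__fr"  0.98
      end
    end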

src/libserver/fasttext/fasttext_shim.cxx
src/lua/lua_fasttext.cxx

diff --git a/src/libserver/fasttext/fasttext_shim.cxx b/src/libserver/fasttext/fasttext_shim.cxx
index 74cac134e19dc446c7415a8f198111beeb7aca7b..380085e4276dad6d6ce8186e3ca15556bc3ed01e 100644
@@ -220,7 +220,7 @@ public:
                return true;
        }
 
-       void add_code(const std::uint8_t *codes, float *vec, std::int32_t dim) const
+       void add_code(const std::uint8_t *codes, float *vec, std::int32_t dim, float alpha) const
        {
                if (centroids_.empty()) return;
                std::int32_t offset = 0;
@@ -232,14 +232,14 @@ public:
 
                        for (std::int32_t d = 0; d < sub_dim; d++) {
                                if (centroid_base + d < centroids_.size() && offset + d < dim) {
-                                       vec[offset + d] += centroids_[centroid_base + d];
+                                       vec[offset + d] += alpha * centroids_[centroid_base + d];
                                }
                        }
                        offset += sub_dim;
                }
        }
 
-       auto dot_code(const std::uint8_t *codes, const float *vec, std::int32_t dim) const -> float
+       auto dot_code(const std::uint8_t *codes, const float *vec, std::int32_t dim, float alpha) const -> float
        {
                if (centroids_.empty()) return 0.0f;
                float result = 0.0f;
@@ -257,7 +257,7 @@ public:
                        }
                        offset += sub_dim;
                }
-               return result;
+               return result * alpha;
        }
 
        auto get_nsubq() const -> std::int32_t
@@ -265,6 +265,16 @@ public:
                return nsubq_;
        }
 
+       /* Look up a single centroid value (used for norm decoding via NPQ) */
+       auto get_centroid_value(std::int32_t sq, std::uint8_t code) const -> float
+       {
+               auto offset = get_centroid_offset(sq, static_cast<std::size_t>(code));
+               if (offset < centroids_.size()) {
+                       return centroids_[offset];
+               }
+               return 1.0f;
+       }
+
 private:
        /* Matches FastText's get_centroids: last sub-quantizer uses lastdsub stride */
        auto get_centroid_offset(std::int32_t sq, std::size_t centroid_idx) const -> std::size_t
@@ -350,12 +360,12 @@ private:
 /* Quantized matrix: codes + product quantizer */
 class quant_matrix final : public matrix_base {
 public:
-       quant_matrix(std::int64_t m, std::int64_t n,
+       quant_matrix(std::int64_t m, std::int64_t n, bool qnorm,
                                 std::vector<std::uint8_t> &&codes,
                                 std::vector<std::uint8_t> &&norm_codes,
                                 product_quantizer &&pq,
                                 product_quantizer &&npq)
-               : m_(m), n_(n),
+               : m_(m), n_(n), qnorm_(qnorm),
                  codes_(std::move(codes)),
                  norm_codes_(std::move(norm_codes)),
                  pq_(std::move(pq)),
@@ -370,7 +380,8 @@ public:
                if (nsubq <= 0) return;
                auto offset = static_cast<std::size_t>(row) * nsubq;
                if (offset + nsubq > codes_.size()) return;
-               pq_.add_code(codes_.data() + offset, vec, dim);
+               auto norm = get_norm(row);
+               pq_.add_code(codes_.data() + offset, vec, dim, norm);
        }
 
        auto dot_row(const float *vec, std::int32_t row, std::int32_t dim) const -> float override
@@ -380,7 +391,8 @@ public:
                if (nsubq <= 0) return 0.0f;
                auto offset = static_cast<std::size_t>(row) * nsubq;
                if (offset + nsubq > codes_.size()) return 0.0f;
-               return pq_.dot_code(codes_.data() + offset, vec, dim);
+               auto norm = get_norm(row);
+               return pq_.dot_code(codes_.data() + offset, vec, dim, norm);
        }
 
        auto rows() const -> std::int64_t override
@@ -393,8 +405,17 @@ public:
        }
 
 private:
+       auto get_norm(std::int32_t row) const -> float
+       {
+               if (qnorm_ && row >= 0 && static_cast<std::size_t>(row) < norm_codes_.size()) {
+                       return npq_.get_centroid_value(0, norm_codes_[row]);
+               }
+               return 1.0f;
+       }
+
        std::int64_t m_;
        std::int64_t n_;
+       bool qnorm_;
        std::vector<std::uint8_t> codes_;
        std::vector<std::uint8_t> norm_codes_;
        product_quantizer pq_;
@@ -494,8 +515,8 @@ public:
                                if (len == 1 && (i == 0 || i + 1 == ncp)) continue;
 
                                auto ngram = word.substr(positions[i], positions[i + len] - positions[i]);
-                               auto h = fnv_hash(ngram) % bucket_;
-                               ngrams.push_back(nwords_ + static_cast<std::int32_t>(h));
+                               auto h = static_cast<std::int32_t>(fnv_hash(ngram) % bucket_);
+                               push_hash(ngrams, h);
                        }
                }
        }
@@ -522,7 +543,40 @@ public:
                return bucket_;
        }
 
+       auto get_label_counts() const -> std::vector<std::int64_t>
+       {
+               std::vector<std::int64_t> counts;
+               counts.reserve(nlabels_);
+               for (std::int32_t i = nwords_; i < nwords_ + nlabels_; i++) {
+                       if (i < static_cast<std::int32_t>(entries_.size())) {
+                               counts.push_back(entries_[i].count);
+                       }
+               }
+               return counts;
+       }
+
 private:
+       /* Matches FastText's Dictionary::pushHash():
+        * - pruneidx_size_ < 0 (e.g. -1): no pruning, push directly
+        * - pruneidx_size_ == 0: all pruned, drop everything
+        * - pruneidx_size_ > 0: remap through pruneidx_ or drop if absent */
+       void push_hash(std::vector<std::int32_t> &ngrams, std::int32_t id) const
+       {
+               if (pruneidx_size_ == 0 || id < 0) {
+                       return;
+               }
+               if (pruneidx_size_ > 0) {
+                       auto it = pruneidx_.find(id);
+                       if (it != pruneidx_.end()) {
+                               id = it->second;
+                       }
+                       else {
+                               return;
+                       }
+               }
+               ngrams.push_back(nwords_ + id);
+       }
+
        std::int32_t nwords_ = 0;
        std::int32_t nlabels_ = 0;
        std::int64_t ntokens_ = 0;
@@ -538,6 +592,20 @@ private:
 };
 
 
+/* Loss function types (matches FastText's enum) */
+constexpr std::int32_t LOSS_HS = 1;
+constexpr std::int32_t LOSS_NS = 2;
+constexpr std::int32_t LOSS_SOFTMAX = 3;
+
+/* Huffman tree node for hierarchical softmax */
+struct hs_node {
+       std::int32_t parent = -1;
+       std::int32_t left = -1;
+       std::int32_t right = -1;
+       std::int64_t count = 0;
+       bool binary = false;
+};
+
 /* --- Model implementation (pimpl) --- */
 class fasttext_model_impl {
 public:
@@ -547,6 +615,8 @@ public:
        std::unique_ptr<matrix_base> output_matrix;
        /* Keep the mmap alive for the lifetime of the model */
        std::optional<rspamd::util::raii_mmaped_file> mmap_file;
+       /* Huffman tree for hierarchical softmax */
+       std::vector<hs_node> hs_tree;
 
        void word2vec(std::string_view word, std::vector<std::int32_t> &ngrams) const
        {
@@ -573,6 +643,51 @@ public:
                }
        }
 
+       /* Build Huffman tree for hierarchical softmax from label counts.
+        * Matches FastText's HierarchicalSoftmaxLoss::buildTree(). */
+       void build_hs_tree()
+       {
+               auto counts = dict.get_label_counts();
+               auto osz = static_cast<std::int32_t>(counts.size());
+               if (osz <= 0) return;
+
+               hs_tree.resize(2 * osz - 1);
+               for (auto &node: hs_tree) {
+                       node.parent = -1;
+                       node.left = -1;
+                       node.right = -1;
+                       node.count = 1e15;
+                       node.binary = false;
+               }
+
+               /* Leaves get their actual counts */
+               for (std::int32_t i = 0; i < osz; i++) {
+                       hs_tree[i].count = counts[i];
+               }
+
+               /* Build tree bottom-up */
+               std::int32_t leaf = osz - 1;
+               std::int32_t node = osz;
+               for (std::int32_t i = osz; i < 2 * osz - 1; i++) {
+                       std::int32_t mini[2] = {0, 0};
+                       for (int j = 0; j < 2; j++) {
+                               if (leaf >= 0 && hs_tree[leaf].count < hs_tree[node].count) {
+                                       mini[j] = leaf--;
+                               }
+                               else {
+                                       mini[j] = node++;
+                               }
+                       }
+                       hs_tree[i].left = mini[0];
+                       hs_tree[i].right = mini[1];
+                       hs_tree[i].count = hs_tree[mini[0]].count + hs_tree[mini[1]].count;
+                       hs_tree[mini[0]].parent = i;
+                       hs_tree[mini[1]].parent = i;
+                       hs_tree[mini[0]].binary = true;
+                       hs_tree[mini[1]].binary = false;
+               }
+       }
+
        void predict(int k, const std::vector<std::int32_t> &word_ids,
                                 std::vector<prediction> &results, float threshold) const
        {
@@ -596,10 +711,23 @@ public:
                        v *= inv_count;
                }
 
-               /* Compute output scores */
+               if (args.loss == LOSS_HS && !hs_tree.empty()) {
+                       predict_hs(k, hidden.data(), dim, nlabels, threshold, results);
+               }
+               else {
+                       predict_softmax(k, hidden.data(), dim, nlabels, threshold, results);
+               }
+       }
+
+private:
+       /* Flat softmax prediction */
+       void predict_softmax(int k, const float *hidden, std::int32_t dim,
+                                                std::int32_t nlabels, float threshold,
+                                                std::vector<prediction> &results) const
+       {
                std::vector<float> scores(nlabels);
                for (std::int32_t i = 0; i < nlabels; i++) {
-                       scores[i] = output_matrix->dot_row(hidden.data(), i, dim);
+                       scores[i] = output_matrix->dot_row(hidden, i, dim);
                }
 
                /* Softmax (numerically stable) */
@@ -615,11 +743,78 @@ public:
                        }
                }
 
-               /* Top-k selection */
+               collect_topk(k, scores, threshold, results);
+       }
+
+       /* Hierarchical softmax prediction via DFS on Huffman tree.
+        * Matches FastText's HierarchicalSoftmaxLoss::predict(). */
+       void predict_hs(int k, const float *hidden, std::int32_t dim,
+                                       std::int32_t osz, float threshold,
+                                       std::vector<prediction> &results) const
+       {
+               using pair_t = std::pair<float, std::int32_t>;
+               auto cmp = [](const pair_t &a, const pair_t &b) { return a.first > b.first; };
+               std::vector<pair_t> heap;
+
+               float log_threshold = std::log(threshold + 1e-5f);
+               dfs_predict(k, log_threshold, 2 * osz - 2, 0.0f, hidden, dim, osz, heap, cmp);
+
+               /* Convert log-probs to probabilities and build results */
+               results.reserve(heap.size());
+               for (auto &[score, idx]: heap) {
+                       auto label = dict.get_label(idx);
+                       results.push_back({std::exp(score), std::string(label)});
+               }
+
+               /* Sort descending by probability */
+               std::sort(results.begin(), results.end(),
+                                 [](const prediction &a, const prediction &b) { return a.prob > b.prob; });
+       }
+
+       /* DFS on the Huffman tree. At each internal node, compute sigmoid
+        * and recurse left (1-f) and right (f). Leaves are labels. */
+       template<typename Cmp>
+       void dfs_predict(int k, float log_threshold, std::int32_t node, float score,
+                                        const float *hidden, std::int32_t dim, std::int32_t osz,
+                                        std::vector<std::pair<float, std::int32_t>> &heap,
+                                        Cmp &cmp) const
+       {
+               if (score < log_threshold) return;
+               if (node < 0 || node >= static_cast<std::int32_t>(hs_tree.size())) return;
+
+               if (node < osz) {
+                       /* Leaf node = label */
+                       if (static_cast<int>(heap.size()) == k && score < heap.front().first) {
+                               return;
+                       }
+                       heap.push_back({score, node});
+                       std::push_heap(heap.begin(), heap.end(), cmp);
+                       if (static_cast<int>(heap.size()) > k) {
+                               std::pop_heap(heap.begin(), heap.end(), cmp);
+                               heap.pop_back();
+                       }
+                       return;
+               }
+
+               /* Internal node: sigmoid of dot product */
+               float f = output_matrix->dot_row(hidden, node - osz, dim);
+               f = 1.0f / (1.0f + std::exp(-f));
+
+               dfs_predict(k, log_threshold, hs_tree[node].left,
+                                       score + std::log(1.0f - f + 1e-5f), hidden, dim, osz, heap, cmp);
+               dfs_predict(k, log_threshold, hs_tree[node].right,
+                                       score + std::log(f + 1e-5f), hidden, dim, osz, heap, cmp);
+       }
+
+       /* Top-k selection from a scores vector, collecting results */
+       void collect_topk(int k, const std::vector<float> &scores,
+                                         float threshold, std::vector<prediction> &results) const
+       {
                using pair_t = std::pair<float, std::int32_t>;
                auto cmp = [](const pair_t &a, const pair_t &b) { return a.first > b.first; };
                std::priority_queue<pair_t, std::vector<pair_t>, decltype(cmp)> heap(cmp);
 
+               auto nlabels = static_cast<std::int32_t>(scores.size());
                for (std::int32_t i = 0; i < nlabels; i++) {
                        if (scores[i] < threshold) continue;
 
@@ -644,6 +839,7 @@ public:
                std::reverse(results.begin(), results.end());
        }
 
+public:
        void get_word_vector(std::vector<float> &vec, std::string_view word) const
        {
                auto dim = args.dim;
@@ -755,7 +951,7 @@ static auto load_quant_matrix(binary_reader &reader)
                return nullptr;
        }
 
-       return std::make_unique<quant_matrix>(m, n,
+       return std::make_unique<quant_matrix>(m, n, qnorm,
                                                                                  std::move(codes),
                                                                                  std::move(norm_codes),
                                                                                  std::move(pq), std::move(npq));
@@ -889,6 +1085,11 @@ auto fasttext_model::load(const char *path) -> tl::expected<fasttext_model, rspa
                                EINVAL));
        }
 
+       /* Build Huffman tree for hierarchical softmax models */
+       if (args.loss == LOSS_HS) {
+               impl->build_hs_tree();
+       }
+
        /* Store the mmap to keep it alive */
        impl->mmap_file.emplace(std::move(*mmap_result));
 
diff --git a/src/lua/lua_fasttext.cxx b/src/lua/lua_fasttext.cxx
index 6a92e9871d903c79087c352d84db610006701123..f44e85ed096c2659ea88b501b004f2f2f070d419 100644
@@ -47,6 +47,7 @@ static int lua_fasttext_load(lua_State *L);
 static int lua_fasttext_model_get_dimension(lua_State *L);
 static int lua_fasttext_model_get_sentence_vector(lua_State *L);
 static int lua_fasttext_model_get_word_vector(lua_State *L);
+static int lua_fasttext_model_predict(lua_State *L);
 static int lua_fasttext_model_dtor(lua_State *L);
 static int lua_fasttext_model_is_loaded(lua_State *L);
 
@@ -61,6 +62,7 @@ static const struct luaL_reg fasttextlib_m[] = {
        {"get_dimension", lua_fasttext_model_get_dimension},
        {"get_sentence_vector", lua_fasttext_model_get_sentence_vector},
        {"get_word_vector", lua_fasttext_model_get_word_vector},
+       {"predict", lua_fasttext_model_predict},
        {"is_loaded", lua_fasttext_model_is_loaded},
        {"__gc", lua_fasttext_model_dtor},
        {"__tostring", rspamd_lua_class_tostring},
@@ -268,6 +270,64 @@ lua_fasttext_model_get_sentence_vector(lua_State *L)
        return 1;
 }
 
+/***
+ * @method model:predict(words, k)
+ * Run supervised classification on a table of words.
+ * Each word is converted to input matrix row IDs internally.
+ * @param {table} words table of word strings
+ * @param {number} k number of top predictions to return (default 1)
+ * @return {table} array of {label=string, prob=number} tables, sorted by probability descending
+ */
+static int
+lua_fasttext_model_predict(lua_State *L)
+{
+       auto *model = lua_check_fasttext_model(L, 1);
+
+       if (!model || !model->loaded) {
+               lua_pushnil(L);
+               return 1;
+       }
+
+       luaL_argcheck(L, lua_istable(L, 2), 2, "'table' of words expected");
+       int k = luaL_optinteger(L, 3, 1);
+
+       /* Convert words to input matrix row IDs */
+       std::vector<std::int32_t> word_ids;
+       auto nwords = rspamd_lua_table_size(L, 2);
+
+       for (auto i = 1; i <= nwords; i++) {
+               lua_rawgeti(L, 2, i);
+               if (lua_isstring(L, -1)) {
+                       std::size_t len;
+                       const char *w = lua_tolstring(L, -1, &len);
+                       if (len > 0) {
+                               model->model->word2vec(std::string_view{w, len}, word_ids);
+                       }
+               }
+               lua_pop(L, 1);
+       }
+
+       if (word_ids.empty()) {
+               lua_newtable(L);
+               return 1;
+       }
+
+       std::vector<rspamd::fasttext::prediction> preds;
+       model->model->predict(k, word_ids, preds, 0.0f);
+
+       lua_createtable(L, static_cast<int>(preds.size()), 0);
+       for (std::size_t i = 0; i < preds.size(); i++) {
+               lua_createtable(L, 0, 2);
+               lua_pushstring(L, preds[i].label.c_str());
+               lua_setfield(L, -2, "label");
+               lua_pushnumber(L, static_cast<double>(preds[i].prob));
+               lua_setfield(L, -2, "prob");
+               lua_rawseti(L, -2, static_cast<int>(i + 1));
+       }
+
+       return 1;
+}
+
 static int
 lua_fasttext_model_dtor(lua_State *L)
 {