From: Vsevolod Stakhov Date: Sat, 21 Feb 2026 18:54:10 +0000 (+0000) Subject: [Fix] Fasttext shim: fix quantized model support and add hierarchical softmax X-Git-Tag: 4.0.0~78^2~2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=8ca2d4354be516a9e60e9dc06b0fe0cc2bce72b7;p=thirdparty%2Frspamd.git [Fix] Fasttext shim: fix quantized model support and add hierarchical softmax - Add pruneidx support in dictionary subword computation: pruned models (like lid.176.ftz) remap subword hash IDs through the prune index table, matching FastText's pushHash() behavior - Apply per-row norm from quantized norm PQ (NPQ) when decoding QMatrix rows, fixing vector magnitude for quantized models - Implement hierarchical softmax prediction via Huffman tree DFS, matching FastText's HierarchicalSoftmaxLoss for models with loss=hs - Add predict() method to Lua rspamd_fasttext bindings --- diff --git a/src/libserver/fasttext/fasttext_shim.cxx b/src/libserver/fasttext/fasttext_shim.cxx index 74cac134e1..380085e427 100644 --- a/src/libserver/fasttext/fasttext_shim.cxx +++ b/src/libserver/fasttext/fasttext_shim.cxx @@ -220,7 +220,7 @@ public: return true; } - void add_code(const std::uint8_t *codes, float *vec, std::int32_t dim) const + void add_code(const std::uint8_t *codes, float *vec, std::int32_t dim, float alpha) const { if (centroids_.empty()) return; std::int32_t offset = 0; @@ -232,14 +232,14 @@ public: for (std::int32_t d = 0; d < sub_dim; d++) { if (centroid_base + d < centroids_.size() && offset + d < dim) { - vec[offset + d] += centroids_[centroid_base + d]; + vec[offset + d] += alpha * centroids_[centroid_base + d]; } } offset += sub_dim; } } - auto dot_code(const std::uint8_t *codes, const float *vec, std::int32_t dim) const -> float + auto dot_code(const std::uint8_t *codes, const float *vec, std::int32_t dim, float alpha) const -> float { if (centroids_.empty()) return 0.0f; float result = 0.0f; @@ -257,7 +257,7 @@ public: } offset += sub_dim; } - return 
result; + return result * alpha; } auto get_nsubq() const -> std::int32_t @@ -265,6 +265,16 @@ public: return nsubq_; } + /* Look up a single centroid value (used for norm decoding via NPQ) */ + auto get_centroid_value(std::int32_t sq, std::uint8_t code) const -> float + { + auto offset = get_centroid_offset(sq, static_cast<std::size_t>(code)); + if (offset < centroids_.size()) { + return centroids_[offset]; + } + return 1.0f; + } + private: /* Matches FastText's get_centroids: last sub-quantizer uses lastdsub stride */ auto get_centroid_offset(std::int32_t sq, std::size_t centroid_idx) const -> std::size_t @@ -350,12 +360,12 @@ private: /* Quantized matrix: codes + product quantizer */ class quant_matrix final : public matrix_base { public: - quant_matrix(std::int64_t m, std::int64_t n, + quant_matrix(std::int64_t m, std::int64_t n, bool qnorm, std::vector<std::uint8_t> &&codes, std::vector<std::uint8_t> &&norm_codes, product_quantizer &&pq, product_quantizer &&npq) - : m_(m), n_(n), + : m_(m), n_(n), qnorm_(qnorm), codes_(std::move(codes)), norm_codes_(std::move(norm_codes)), pq_(std::move(pq)), @@ -370,7 +380,8 @@ public: if (nsubq <= 0) return; auto offset = static_cast<std::size_t>(row) * nsubq; if (offset + nsubq > codes_.size()) return; - pq_.add_code(codes_.data() + offset, vec, dim); + auto norm = get_norm(row); + pq_.add_code(codes_.data() + offset, vec, dim, norm); } auto dot_row(const float *vec, std::int32_t row, std::int32_t dim) const -> float override @@ -380,7 +391,8 @@ public: if (nsubq <= 0) return 0.0f; auto offset = static_cast<std::size_t>(row) * nsubq; if (offset + nsubq > codes_.size()) return 0.0f; - return pq_.dot_code(codes_.data() + offset, vec, dim); + auto norm = get_norm(row); + return pq_.dot_code(codes_.data() + offset, vec, dim, norm); } auto rows() const -> std::int64_t override @@ -393,8 +405,17 @@ public: } private: + auto get_norm(std::int32_t row) const -> float + { + if (qnorm_ && row >= 0 && static_cast<std::size_t>(row) < norm_codes_.size()) { + return npq_.get_centroid_value(0, norm_codes_[row]); + } 
+ return 1.0f; + } + std::int64_t m_; std::int64_t n_; + bool qnorm_; std::vector<std::uint8_t> codes_; std::vector<std::uint8_t> norm_codes_; product_quantizer pq_; @@ -494,8 +515,8 @@ public: if (len == 1 && (i == 0 || i + 1 == ncp)) continue; auto ngram = word.substr(positions[i], positions[i + len] - positions[i]); - auto h = fnv_hash(ngram) % bucket_; - ngrams.push_back(nwords_ + static_cast<std::int32_t>(h)); + auto h = static_cast<std::int32_t>(fnv_hash(ngram) % bucket_); + push_hash(ngrams, h); } } } @@ -522,7 +543,40 @@ public: return bucket_; } + auto get_label_counts() const -> std::vector<std::int64_t> + { + std::vector<std::int64_t> counts; + counts.reserve(nlabels_); + for (std::int32_t i = nwords_; i < nwords_ + nlabels_; i++) { + if (i < static_cast<std::int32_t>(entries_.size())) { + counts.push_back(entries_[i].count); + } + } + return counts; + } + private: + /* Matches FastText's Dictionary::pushHash(): + * - pruneidx_size_ < 0 (e.g. -1): no pruning, push directly + * - pruneidx_size_ == 0: all pruned, drop everything + * - pruneidx_size_ > 0: remap through pruneidx_ or drop if absent */ + void push_hash(std::vector<std::int32_t> &ngrams, std::int32_t id) const + { + if (pruneidx_size_ == 0 || id < 0) { + return; + } + if (pruneidx_size_ > 0) { + auto it = pruneidx_.find(id); + if (it != pruneidx_.end()) { + id = it->second; + } + else { + return; + } + } + ngrams.push_back(nwords_ + id); + } + std::int32_t nwords_ = 0; std::int32_t nlabels_ = 0; std::int64_t ntokens_ = 0; @@ -538,6 +592,20 @@ private: }; +/* Loss function types (matches FastText's enum) */ +constexpr std::int32_t LOSS_HS = 1; +constexpr std::int32_t LOSS_NS = 2; +constexpr std::int32_t LOSS_SOFTMAX = 3; + +/* Huffman tree node for hierarchical softmax */ +struct hs_node { + std::int32_t parent = -1; + std::int32_t left = -1; + std::int32_t right = -1; + std::int64_t count = 0; + bool binary = false; +}; + /* --- Model implementation (pimpl) --- */ class fasttext_model_impl { public: @@ -547,6 +615,8 @@ public: std::unique_ptr<matrix_base> output_matrix; /* Keep the mmap alive for the lifetime of the 
model */ std::optional mmap_file; + /* Huffman tree for hierarchical softmax */ + std::vector<hs_node> hs_tree; void word2vec(std::string_view word, std::vector<std::int32_t> &ngrams) const { @@ -573,6 +643,51 @@ public: } } + /* Build Huffman tree for hierarchical softmax from label counts. + * Matches FastText's HierarchicalSoftmaxLoss::buildTree(). */ + void build_hs_tree() + { + auto counts = dict.get_label_counts(); + auto osz = static_cast<std::int32_t>(counts.size()); + if (osz <= 0) return; + + hs_tree.resize(2 * osz - 1); + for (auto &node: hs_tree) { + node.parent = -1; + node.left = -1; + node.right = -1; + node.count = 1e15; + node.binary = false; + } + + /* Leaves get their actual counts */ + for (std::int32_t i = 0; i < osz; i++) { + hs_tree[i].count = counts[i]; + } + + /* Build tree bottom-up */ + std::int32_t leaf = osz - 1; + std::int32_t node = osz; + for (std::int32_t i = osz; i < 2 * osz - 1; i++) { + std::int32_t mini[2] = {0, 0}; + for (int j = 0; j < 2; j++) { + if (leaf >= 0 && hs_tree[leaf].count < hs_tree[node].count) { + mini[j] = leaf--; + } + else { + mini[j] = node++; + } + } + hs_tree[i].left = mini[0]; + hs_tree[i].right = mini[1]; + hs_tree[i].count = hs_tree[mini[0]].count + hs_tree[mini[1]].count; + hs_tree[mini[0]].parent = i; + hs_tree[mini[1]].parent = i; + hs_tree[mini[0]].binary = true; + hs_tree[mini[1]].binary = false; + } + } + void predict(int k, const std::vector<std::int32_t> &word_ids, std::vector<prediction> &results, float threshold) const { @@ -596,10 +711,23 @@ public: v *= inv_count; } - /* Compute output scores */ + if (args.loss == LOSS_HS && !hs_tree.empty()) { + predict_hs(k, hidden.data(), dim, nlabels, threshold, results); + } + else { + predict_softmax(k, hidden.data(), dim, nlabels, threshold, results); + } + } + +private: + /* Flat softmax prediction */ + void predict_softmax(int k, const float *hidden, std::int32_t dim, + std::int32_t nlabels, float threshold, + std::vector<prediction> &results) const + { std::vector<float> scores(nlabels); for (std::int32_t i = 0; i < nlabels; i++) 
{ - scores[i] = output_matrix->dot_row(hidden.data(), i, dim); + scores[i] = output_matrix->dot_row(hidden, i, dim); } /* Softmax (numerically stable) */ @@ -615,11 +743,78 @@ public: } } - /* Top-k selection */ + collect_topk(k, scores, threshold, results); + } + + /* Hierarchical softmax prediction via DFS on Huffman tree. + * Matches FastText's HierarchicalSoftmaxLoss::predict(). */ + void predict_hs(int k, const float *hidden, std::int32_t dim, + std::int32_t osz, float threshold, + std::vector<prediction> &results) const + { + using pair_t = std::pair<float, std::int32_t>; + auto cmp = [](const pair_t &a, const pair_t &b) { return a.first > b.first; }; + std::vector<pair_t> heap; + + float log_threshold = std::log(threshold + 1e-5f); + dfs_predict(k, log_threshold, 2 * osz - 2, 0.0f, hidden, dim, osz, heap, cmp); + + /* Convert log-probs to probabilities and build results */ + results.reserve(heap.size()); + for (auto &[score, idx]: heap) { + auto label = dict.get_label(idx); + results.push_back({std::exp(score), std::string(label)}); + } + + /* Sort descending by probability */ + std::sort(results.begin(), results.end(), + [](const prediction &a, const prediction &b) { return a.prob > b.prob; }); + } + + /* DFS on the Huffman tree. At each internal node, compute sigmoid + * and recurse left (1-f) and right (f). Leaves are labels. 
*/ + template<typename Cmp> + void dfs_predict(int k, float log_threshold, std::int32_t node, float score, + const float *hidden, std::int32_t dim, std::int32_t osz, + std::vector<std::pair<float, std::int32_t>> &heap, + Cmp &cmp) const + { + if (score < log_threshold) return; + if (node < 0 || node >= static_cast<std::int32_t>(hs_tree.size())) return; + + if (node < osz) { + /* Leaf node = label */ + if (static_cast<int>(heap.size()) == k && score < heap.front().first) { + return; + } + heap.push_back({score, node}); + std::push_heap(heap.begin(), heap.end(), cmp); + if (static_cast<int>(heap.size()) > k) { + std::pop_heap(heap.begin(), heap.end(), cmp); + heap.pop_back(); + } + return; + } + + /* Internal node: sigmoid of dot product */ + float f = output_matrix->dot_row(hidden, node - osz, dim); + f = 1.0f / (1.0f + std::exp(-f)); + + dfs_predict(k, log_threshold, hs_tree[node].left, + score + std::log(1.0f - f + 1e-5f), hidden, dim, osz, heap, cmp); + dfs_predict(k, log_threshold, hs_tree[node].right, + score + std::log(f + 1e-5f), hidden, dim, osz, heap, cmp); + } + + /* Top-k selection from a scores vector, collecting results */ + void collect_topk(int k, const std::vector<float> &scores, + float threshold, std::vector<prediction> &results) const + { using pair_t = std::pair<float, std::int32_t>; auto cmp = [](const pair_t &a, const pair_t &b) { return a.first > b.first; }; std::priority_queue<pair_t, std::vector<pair_t>, decltype(cmp)> heap(cmp); + auto nlabels = static_cast<std::int32_t>(scores.size()); for (std::int32_t i = 0; i < nlabels; i++) { if (scores[i] < threshold) continue; @@ -644,6 +839,7 @@ public: std::reverse(results.begin(), results.end()); } +public: void get_word_vector(std::vector<float> &vec, std::string_view word) const { auto dim = args.dim; @@ -755,7 +951,7 @@ static auto load_quant_matrix(binary_reader &reader) return nullptr; } - return std::make_unique<quant_matrix>(m, n, + return std::make_unique<quant_matrix>(m, n, qnorm, std::move(codes), std::move(norm_codes), std::move(pq), std::move(npq)); @@ -889,6 +1085,11 @@ auto fasttext_model::load(const char *path) -> tl::expectedbuild_hs_tree(); + } + /* Store the 
mmap to keep it alive */ impl->mmap_file.emplace(std::move(*mmap_result)); diff --git a/src/lua/lua_fasttext.cxx b/src/lua/lua_fasttext.cxx index 6a92e9871d..f44e85ed09 100644 --- a/src/lua/lua_fasttext.cxx +++ b/src/lua/lua_fasttext.cxx @@ -47,6 +47,7 @@ static int lua_fasttext_load(lua_State *L); static int lua_fasttext_model_get_dimension(lua_State *L); static int lua_fasttext_model_get_sentence_vector(lua_State *L); static int lua_fasttext_model_get_word_vector(lua_State *L); +static int lua_fasttext_model_predict(lua_State *L); static int lua_fasttext_model_dtor(lua_State *L); static int lua_fasttext_model_is_loaded(lua_State *L); @@ -61,6 +62,7 @@ static const struct luaL_reg fasttextlib_m[] = { {"get_dimension", lua_fasttext_model_get_dimension}, {"get_sentence_vector", lua_fasttext_model_get_sentence_vector}, {"get_word_vector", lua_fasttext_model_get_word_vector}, + {"predict", lua_fasttext_model_predict}, {"is_loaded", lua_fasttext_model_is_loaded}, {"__gc", lua_fasttext_model_dtor}, {"__tostring", rspamd_lua_class_tostring}, @@ -268,6 +270,64 @@ lua_fasttext_model_get_sentence_vector(lua_State *L) return 1; } +/*** + * @method model:predict(words, k) + * Run supervised classification on a table of words. + * Each word is converted to input matrix row IDs internally. 
+ * @param {table} words table of word strings + * @param {number} k number of top predictions to return (default 1) + * @return {table} array of {label=string, prob=number} tables, sorted by probability descending + */ +static int +lua_fasttext_model_predict(lua_State *L) +{ + auto *model = lua_check_fasttext_model(L, 1); + + if (!model || !model->loaded) { + lua_pushnil(L); + return 1; + } + + luaL_argcheck(L, lua_istable(L, 2), 2, "'table' of words expected"); + int k = luaL_optinteger(L, 3, 1); + + /* Convert words to input matrix row IDs */ + std::vector<std::int32_t> word_ids; + auto nwords = rspamd_lua_table_size(L, 2); + + for (auto i = 1; i <= nwords; i++) { + lua_rawgeti(L, 2, i); + if (lua_isstring(L, -1)) { + std::size_t len; + const char *w = lua_tolstring(L, -1, &len); + if (len > 0) { + model->model->word2vec(std::string_view{w, len}, word_ids); + } + } + lua_pop(L, 1); + } + + if (word_ids.empty()) { + lua_newtable(L); + return 1; + } + + std::vector preds; + model->model->predict(k, word_ids, preds, 0.0f); + + lua_createtable(L, static_cast<int>(preds.size()), 0); + for (std::size_t i = 0; i < preds.size(); i++) { + lua_createtable(L, 0, 2); + lua_pushstring(L, preds[i].label.c_str()); + lua_setfield(L, -2, "label"); + lua_pushnumber(L, static_cast(preds[i].prob)); + lua_setfield(L, -2, "prob"); + lua_rawseti(L, -2, static_cast<int>(i + 1)); + } + + return 1; +} + static int lua_fasttext_model_dtor(lua_State *L) {