return true;
}
- void add_code(const std::uint8_t *codes, float *vec, std::int32_t dim) const
+ void add_code(const std::uint8_t *codes, float *vec, std::int32_t dim, float alpha) const
{
if (centroids_.empty()) return;
std::int32_t offset = 0;
for (std::int32_t d = 0; d < sub_dim; d++) {
if (centroid_base + d < centroids_.size() && offset + d < dim) {
- vec[offset + d] += centroids_[centroid_base + d];
+ vec[offset + d] += alpha * centroids_[centroid_base + d];
}
}
offset += sub_dim;
}
}
- auto dot_code(const std::uint8_t *codes, const float *vec, std::int32_t dim) const -> float
+ auto dot_code(const std::uint8_t *codes, const float *vec, std::int32_t dim, float alpha) const -> float
{
if (centroids_.empty()) return 0.0f;
float result = 0.0f;
}
offset += sub_dim;
}
- return result;
+ return result * alpha;
}
auto get_nsubq() const -> std::int32_t
return nsubq_;
}
+ /* Look up a single centroid value (used for norm decoding via NPQ) */
+ /* @param sq    sub-quantizer index (NPQ stores one value per row, so callers pass 0)
+ * @param code  quantization code selecting the centroid
+ * @return the centroid value, or 1.0f (the multiplicative identity) when the
+ * computed offset is out of bounds — a corrupt code then degrades to an
+ * unscaled vector instead of reading out of range */
+ auto get_centroid_value(std::int32_t sq, std::uint8_t code) const -> float
+ {
+ auto offset = get_centroid_offset(sq, static_cast<std::size_t>(code));
+ if (offset < centroids_.size()) {
+ return centroids_[offset];
+ }
+ return 1.0f;
+ }
+
private:
/* Matches FastText's get_centroids: last sub-quantizer uses lastdsub stride */
auto get_centroid_offset(std::int32_t sq, std::size_t centroid_idx) const -> std::size_t
/* Quantized matrix: codes + product quantizer */
class quant_matrix final : public matrix_base {
public:
- quant_matrix(std::int64_t m, std::int64_t n,
+ quant_matrix(std::int64_t m, std::int64_t n, bool qnorm,
std::vector<std::uint8_t> &&codes,
std::vector<std::uint8_t> &&norm_codes,
product_quantizer &&pq,
product_quantizer &&npq)
- : m_(m), n_(n),
+ : m_(m), n_(n), qnorm_(qnorm),
codes_(std::move(codes)),
norm_codes_(std::move(norm_codes)),
pq_(std::move(pq)),
if (nsubq <= 0) return;
auto offset = static_cast<std::size_t>(row) * nsubq;
if (offset + nsubq > codes_.size()) return;
- pq_.add_code(codes_.data() + offset, vec, dim);
+ auto norm = get_norm(row);
+ pq_.add_code(codes_.data() + offset, vec, dim, norm);
}
auto dot_row(const float *vec, std::int32_t row, std::int32_t dim) const -> float override
if (nsubq <= 0) return 0.0f;
auto offset = static_cast<std::size_t>(row) * nsubq;
if (offset + nsubq > codes_.size()) return 0.0f;
- return pq_.dot_code(codes_.data() + offset, vec, dim);
+ auto norm = get_norm(row);
+ return pq_.dot_code(codes_.data() + offset, vec, dim, norm);
}
auto rows() const -> std::int64_t override
}
private:
+ /* Decode the per-row norm via the norm product quantizer (NPQ).
+ * Returns 1.0f (no scaling) when norms were not quantized (qnorm_ == false)
+ * or when the row has no stored norm code.
+ * Sub-quantizer index 0 is used because the NPQ holds a single value per row. */
+ auto get_norm(std::int32_t row) const -> float
+ {
+ if (qnorm_ && row >= 0 && static_cast<std::size_t>(row) < norm_codes_.size()) {
+ return npq_.get_centroid_value(0, norm_codes_[row]);
+ }
+ return 1.0f;
+ }
+
std::int64_t m_;
std::int64_t n_;
+ bool qnorm_;
std::vector<std::uint8_t> codes_;
std::vector<std::uint8_t> norm_codes_;
product_quantizer pq_;
if (len == 1 && (i == 0 || i + 1 == ncp)) continue;
auto ngram = word.substr(positions[i], positions[i + len] - positions[i]);
- auto h = fnv_hash(ngram) % bucket_;
- ngrams.push_back(nwords_ + static_cast<std::int32_t>(h));
+ auto h = static_cast<std::int32_t>(fnv_hash(ngram) % bucket_);
+ push_hash(ngrams, h);
}
}
}
return bucket_;
}
+ /* Collect occurrence counts for the label entries (stored after the
+ * nwords_ word entries, i.e. indices [nwords_, nwords_ + nlabels_)).
+ * Used to build the Huffman tree for hierarchical softmax.
+ * NOTE(review): if entries_ is shorter than nwords_ + nlabels_, the result
+ * has fewer than nlabels_ elements and label indices derived from it would
+ * shift — presumably entries_ always holds all labels; verify on load. */
+ auto get_label_counts() const -> std::vector<std::int64_t>
+ {
+ std::vector<std::int64_t> counts;
+ counts.reserve(nlabels_);
+ for (std::int32_t i = nwords_; i < nwords_ + nlabels_; i++) {
+ if (i < static_cast<std::int32_t>(entries_.size())) {
+ counts.push_back(entries_[i].count);
+ }
+ }
+ return counts;
+ }
+
private:
+ /* Matches FastText's Dictionary::pushHash():
+ * - pruneidx_size_ < 0 (e.g. -1): no pruning, push directly
+ * - pruneidx_size_ == 0: all pruned, drop everything
+ * - pruneidx_size_ > 0: remap through pruneidx_ or drop if absent
+ * The surviving id is offset by nwords_ so ngram buckets live after the
+ * word rows in the combined input-matrix id space. */
+ void push_hash(std::vector<std::int32_t> &ngrams, std::int32_t id) const
+ {
+ if (pruneidx_size_ == 0 || id < 0) {
+ return;
+ }
+ if (pruneidx_size_ > 0) {
+ auto it = pruneidx_.find(id);
+ if (it != pruneidx_.end()) {
+ id = it->second;
+ }
+ else {
+ return;
+ }
+ }
+ ngrams.push_back(nwords_ + id);
+ }
+
std::int32_t nwords_ = 0;
std::int32_t nlabels_ = 0;
std::int64_t ntokens_ = 0;
};
+/* Loss function types (matches FastText's enum) */
+constexpr std::int32_t LOSS_HS = 1;
+constexpr std::int32_t LOSS_NS = 2;
+constexpr std::int32_t LOSS_SOFTMAX = 3;
+
+/* Huffman tree node for hierarchical softmax */
+struct hs_node {
+ std::int32_t parent = -1; /* index of parent node, -1 for the root */
+ std::int32_t left = -1; /* left child index, -1 for leaves */
+ std::int32_t right = -1; /* right child index, -1 for leaves */
+ std::int64_t count = 0; /* Huffman weight (label count, or sum for internal nodes) */
+ bool binary = false; /* code bit toward the parent (unused by inference-only predict) */
+};
+
/* --- Model implementation (pimpl) --- */
class fasttext_model_impl {
public:
std::unique_ptr<matrix_base> output_matrix;
/* Keep the mmap alive for the lifetime of the model */
std::optional<rspamd::util::raii_mmaped_file> mmap_file;
+ /* Huffman tree for hierarchical softmax */
+ std::vector<hs_node> hs_tree;
void word2vec(std::string_view word, std::vector<std::int32_t> &ngrams) const
{
}
}
+ /* Build Huffman tree for hierarchical softmax from label counts.
+ * Matches FastText's HierarchicalSoftmaxLoss::buildTree().
+ * Layout: nodes [0, osz) are leaves (labels), [osz, 2*osz-1) are internal
+ * nodes built bottom-up; the last node (2*osz-2) is the root.
+ * NOTE(review): this construction assumes counts[] is sorted in decreasing
+ * order (the leaf cursor scans from the smallest count at osz-1 upward) —
+ * FastText's dictionary sorts entries by count; confirm get_label_counts()
+ * preserves that order. */
+ void build_hs_tree()
+ {
+ auto counts = dict.get_label_counts();
+ auto osz = static_cast<std::int32_t>(counts.size());
+ if (osz <= 0) return;
+
+ hs_tree.resize(2 * osz - 1);
+ for (auto &node: hs_tree) {
+ node.parent = -1;
+ node.left = -1;
+ node.right = -1;
+ node.count = 1e15; /* sentinel: "infinite" weight for not-yet-built internal nodes */
+ node.binary = false;
+ }
+
+ /* Leaves get their actual counts */
+ for (std::int32_t i = 0; i < osz; i++) {
+ hs_tree[i].count = counts[i];
+ }
+
+ /* Build tree bottom-up */
+ std::int32_t leaf = osz - 1; /* cursor over leaves, smallest count first */
+ std::int32_t node = osz; /* cursor over already-built internal nodes */
+ for (std::int32_t i = osz; i < 2 * osz - 1; i++) {
+ /* Pick the two lightest available nodes (leaf vs internal) as children */
+ std::int32_t mini[2] = {0, 0};
+ for (int j = 0; j < 2; j++) {
+ if (leaf >= 0 && hs_tree[leaf].count < hs_tree[node].count) {
+ mini[j] = leaf--;
+ }
+ else {
+ mini[j] = node++;
+ }
+ }
+ hs_tree[i].left = mini[0];
+ hs_tree[i].right = mini[1];
+ hs_tree[i].count = hs_tree[mini[0]].count + hs_tree[mini[1]].count;
+ hs_tree[mini[0]].parent = i;
+ hs_tree[mini[1]].parent = i;
+ /* NOTE(review): FastText sets tree_[mini[1]].binary = true (right child);
+ * here the flags are inverted. Harmless for prediction — predict_hs only
+ * walks left/right and never reads binary — but verify before reusing this
+ * tree for code/path generation. */
+ hs_tree[mini[0]].binary = true;
+ hs_tree[mini[1]].binary = false;
+ }
+ }
+
void predict(int k, const std::vector<std::int32_t> &word_ids,
std::vector<prediction> &results, float threshold) const
{
v *= inv_count;
}
- /* Compute output scores */
+ if (args.loss == LOSS_HS && !hs_tree.empty()) {
+ predict_hs(k, hidden.data(), dim, nlabels, threshold, results);
+ }
+ else {
+ predict_softmax(k, hidden.data(), dim, nlabels, threshold, results);
+ }
+ }
+
+private:
+ /* Flat softmax prediction */
+ void predict_softmax(int k, const float *hidden, std::int32_t dim,
+ std::int32_t nlabels, float threshold,
+ std::vector<prediction> &results) const
+ {
std::vector<float> scores(nlabels);
for (std::int32_t i = 0; i < nlabels; i++) {
- scores[i] = output_matrix->dot_row(hidden.data(), i, dim);
+ scores[i] = output_matrix->dot_row(hidden, i, dim);
}
/* Softmax (numerically stable) */
}
}
- /* Top-k selection */
+ collect_topk(k, scores, threshold, results);
+ }
+
+ /* Hierarchical softmax prediction via DFS on Huffman tree.
+ * Matches FastText's HierarchicalSoftmaxLoss::predict().
+ * @param k          number of predictions to keep
+ * @param hidden     averaged hidden vector (dim floats)
+ * @param osz        output size (number of labels / leaves)
+ * @param threshold  minimum probability; converted to log-space for pruning
+ * @param results    filled with top-k (prob, label), sorted descending */
+ void predict_hs(int k, const float *hidden, std::int32_t dim,
+ std::int32_t osz, float threshold,
+ std::vector<prediction> &results) const
+ {
+ /* Min-heap on log-prob: heap.front() is the worst of the kept k */
+ using pair_t = std::pair<float, std::int32_t>;
+ auto cmp = [](const pair_t &a, const pair_t &b) { return a.first > b.first; };
+ std::vector<pair_t> heap;
+
+ /* Root is the last-built internal node (index 2*osz-2); the +1e-5f epsilon
+ * keeps log() finite for threshold == 0, as in FastText's std_log */
+ float log_threshold = std::log(threshold + 1e-5f);
+ dfs_predict(k, log_threshold, 2 * osz - 2, 0.0f, hidden, dim, osz, heap, cmp);
+
+ /* Convert log-probs to probabilities and build results */
+ results.reserve(heap.size());
+ for (auto &[score, idx]: heap) {
+ auto label = dict.get_label(idx);
+ results.push_back({std::exp(score), std::string(label)});
+ }
+
+ /* Sort descending by probability */
+ std::sort(results.begin(), results.end(),
+ [](const prediction &a, const prediction &b) { return a.prob > b.prob; });
+ }
+
+ /* DFS on the Huffman tree. At each internal node, compute sigmoid
+ * and recurse left (1-f) and right (f). Leaves are labels.
+ * score accumulates log-probability along the path; branches are pruned
+ * when the accumulated score can no longer beat the threshold or the
+ * worst of the k results kept so far. Recursion depth is bounded by the
+ * tree height (worst case O(osz) for highly skewed label counts). */
+ template<typename Cmp>
+ void dfs_predict(int k, float log_threshold, std::int32_t node, float score,
+ const float *hidden, std::int32_t dim, std::int32_t osz,
+ std::vector<std::pair<float, std::int32_t>> &heap,
+ Cmp &cmp) const
+ {
+ /* Prune: below probability threshold, or malformed node index */
+ if (score < log_threshold) return;
+ if (node < 0 || node >= static_cast<std::int32_t>(hs_tree.size())) return;
+
+ if (node < osz) {
+ /* Leaf node = label; indices [0, osz) are leaves by construction */
+ if (static_cast<int>(heap.size()) == k && score < heap.front().first) {
+ return;
+ }
+ heap.push_back({score, node});
+ std::push_heap(heap.begin(), heap.end(), cmp);
+ if (static_cast<int>(heap.size()) > k) {
+ /* Evict the worst (min-heap front) to keep only k entries */
+ std::pop_heap(heap.begin(), heap.end(), cmp);
+ heap.pop_back();
+ }
+ return;
+ }
+
+ /* Internal node: sigmoid of dot product; rows of the output matrix map
+ * to internal nodes offset by osz */
+ float f = output_matrix->dot_row(hidden, node - osz, dim);
+ f = 1.0f / (1.0f + std::exp(-f));
+
+ /* The 1e-5f epsilon keeps log() finite when f saturates at 0 or 1 */
+ dfs_predict(k, log_threshold, hs_tree[node].left,
+ score + std::log(1.0f - f + 1e-5f), hidden, dim, osz, heap, cmp);
+ dfs_predict(k, log_threshold, hs_tree[node].right,
+ score + std::log(f + 1e-5f), hidden, dim, osz, heap, cmp);
+ }
+
+ /* Top-k selection from a scores vector, collecting results */
+ void collect_topk(int k, const std::vector<float> &scores,
+ float threshold, std::vector<prediction> &results) const
+ {
using pair_t = std::pair<float, std::int32_t>;
auto cmp = [](const pair_t &a, const pair_t &b) { return a.first > b.first; };
std::priority_queue<pair_t, std::vector<pair_t>, decltype(cmp)> heap(cmp);
+ auto nlabels = static_cast<std::int32_t>(scores.size());
for (std::int32_t i = 0; i < nlabels; i++) {
if (scores[i] < threshold) continue;
std::reverse(results.begin(), results.end());
}
+public:
void get_word_vector(std::vector<float> &vec, std::string_view word) const
{
auto dim = args.dim;
return nullptr;
}
- return std::make_unique<quant_matrix>(m, n,
+ return std::make_unique<quant_matrix>(m, n, qnorm,
std::move(codes),
std::move(norm_codes),
std::move(pq), std::move(npq));
EINVAL));
}
+ /* Build Huffman tree for hierarchical softmax models */
+ if (args.loss == LOSS_HS) {
+ impl->build_hs_tree();
+ }
+
/* Store the mmap to keep it alive */
impl->mmap_file.emplace(std::move(*mmap_result));
static int lua_fasttext_model_get_dimension(lua_State *L);
static int lua_fasttext_model_get_sentence_vector(lua_State *L);
static int lua_fasttext_model_get_word_vector(lua_State *L);
+static int lua_fasttext_model_predict(lua_State *L);
static int lua_fasttext_model_dtor(lua_State *L);
static int lua_fasttext_model_is_loaded(lua_State *L);
{"get_dimension", lua_fasttext_model_get_dimension},
{"get_sentence_vector", lua_fasttext_model_get_sentence_vector},
{"get_word_vector", lua_fasttext_model_get_word_vector},
+ {"predict", lua_fasttext_model_predict},
{"is_loaded", lua_fasttext_model_is_loaded},
{"__gc", lua_fasttext_model_dtor},
{"__tostring", rspamd_lua_class_tostring},
return 1;
}
+/***
+ * @method model:predict(words, k)
+ * Run supervised classification on a table of words.
+ * Each word is converted to input matrix row IDs internally.
+ * @param {table} words table of word strings
+ * @param {number} k number of top predictions to return (default 1; values < 1 are treated as 1)
+ * @return {table} array of {label=string, prob=number} tables, sorted by probability descending
+ */
+static int
+lua_fasttext_model_predict(lua_State *L)
+{
+ auto *model = lua_check_fasttext_model(L, 1);
+
+ if (!model || !model->loaded) {
+ lua_pushnil(L);
+ return 1;
+ }
+
+ luaL_argcheck(L, lua_istable(L, 2), 2, "'table' of words expected");
+ int k = luaL_optinteger(L, 3, 1);
+ /* Guard against k <= 0 from Lua callers: predict()'s top-k heap logic
+ * assumes a positive bound, so clamp to at least one prediction */
+ if (k < 1) {
+ k = 1;
+ }
+
+ /* Convert words to input matrix row IDs (word row + subword ngram rows) */
+ std::vector<std::int32_t> word_ids;
+ auto nwords = rspamd_lua_table_size(L, 2);
+
+ for (auto i = 1; i <= nwords; i++) {
+ lua_rawgeti(L, 2, i);
+ if (lua_isstring(L, -1)) {
+ std::size_t len;
+ const char *w = lua_tolstring(L, -1, &len);
+ if (len > 0) {
+ model->model->word2vec(std::string_view{w, len}, word_ids);
+ }
+ }
+ lua_pop(L, 1);
+ }
+
+ /* No recognized words: return an empty table rather than nil */
+ if (word_ids.empty()) {
+ lua_newtable(L);
+ return 1;
+ }
+
+ std::vector<rspamd::fasttext::prediction> preds;
+ model->model->predict(k, word_ids, preds, 0.0f);
+
+ /* Build the result array: { {label=..., prob=...}, ... } */
+ lua_createtable(L, static_cast<int>(preds.size()), 0);
+ for (std::size_t i = 0; i < preds.size(); i++) {
+ lua_createtable(L, 0, 2);
+ lua_pushstring(L, preds[i].label.c_str());
+ lua_setfield(L, -2, "label");
+ lua_pushnumber(L, static_cast<double>(preds[i].prob));
+ lua_setfield(L, -2, "prob");
+ lua_rawseti(L, -2, static_cast<int>(i + 1));
+ }
+
+ return 1;
+}
+
static int
lua_fasttext_model_dtor(lua_State *L)
{