]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Fix] Fasttext shim: fix binary format parsing and harden against corrupt models
authorVsevolod Stakhov <vsevolod@rspamd.com>
Sat, 21 Feb 2026 17:27:42 +0000 (17:27 +0000)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Sat, 21 Feb 2026 17:27:42 +0000 (17:27 +0000)
- Fix QMatrix load order: read codesize+codes before PQ (not after)
- Fix PQ centroid count: use dim*ksub (not nsubq*ksub*dsub)
- Fix PQ centroid addressing: match FastText's get_centroids() for last sub-quantizer
- Fix dictionary load: read size_ field before nwords/nlabels
- Fix output matrix: always read qout bool between input and output matrices
- Fix subword n-gram skip: only skip single-char BOW/EOW, not full wrapped word

Add comprehensive sanity checks for all untrusted values from model files:
- Validate dimensions, entry counts, matrix sizes against sane upper bounds
- Overflow-safe multiplication for matrix element counts
- Bounds checks on centroid, codes, and dense matrix data access
- Null pointer guards on all matrix operations
- Replace throwing .at() with bounds-checked pointer return
- Limit string reads to 1024 bytes to prevent runaway allocation
- Return nullptr/false from loaders on validation failure
- Guard Lua bindings against empty/short vectors from get_word_vector

src/libserver/fasttext/fasttext_shim.cxx
src/lua/lua_fasttext.cxx

index 6a51a6f57c6eb99a139de4c7ae040d0d14f183b4..74cac134e19dc446c7415a8f198111beeb7aca7b 100644 (file)
@@ -33,6 +33,14 @@ namespace rspamd::fasttext {
 
 namespace {
 
+/* Sanity limits for untrusted values read from model files */
+constexpr std::int32_t MAX_SANE_DIM = 4096;
+constexpr std::int32_t MAX_SANE_ENTRIES = 50'000'000;
+constexpr std::int64_t MAX_SANE_MATRIX_ROWS = 500'000'000;
+constexpr std::int32_t MAX_SANE_CODESIZE = 500'000'000;
+constexpr std::int32_t MAX_SANE_BUCKET = 100'000'000;
+constexpr std::size_t MAX_SANE_STRING = 1024;
+
 /* --- Binary reader: a cursor over memory-mapped data --- */
 /* Uses a fail-bit pattern instead of exceptions: once any read overflows,
  * the reader enters a failed state and all subsequent reads return zeroes.
@@ -110,12 +118,20 @@ public:
                        auto ch = data_[pos_++];
                        if (ch == 0) break;
                        result.push_back(static_cast<char>(ch));
+                       if (result.size() > MAX_SANE_STRING) {
+                               failed_ = true;
+                               return {};
+                       }
                }
                return result;
        }
 
        auto read_floats(std::size_t count) -> const float *
        {
+               if (count > SIZE_MAX / sizeof(float)) {
+                       failed_ = true;
+                       return nullptr;
+               }
                auto bytes = count * sizeof(float);
                if (!ensure(bytes)) return nullptr;
                auto ptr = reinterpret_cast<const float *>(data_ + pos_);
@@ -182,35 +198,41 @@ struct dict_entry {
 /* --- Product Quantizer --- */
 class product_quantizer {
 public:
-       void load(binary_reader &reader)
+       auto load(binary_reader &reader) -> bool
        {
                dim_ = reader.read_i32();
                nsubq_ = reader.read_i32();
                dsub_ = reader.read_i32();
                lastdsub_ = reader.read_i32();
 
-               auto centroid_count = static_cast<std::size_t>(nsubq_) * ksub_;
-               auto total_floats = centroid_count * dsub_;
-               /* Actually, centroids are stored as nsubq * ksub * dsub floats,
-                * but the last subquantizer may use lastdsub. FastText stores
-                * all centroids with dsub stride. */
+               if (reader.fail() || dim_ <= 0 || dim_ > MAX_SANE_DIM ||
+                       nsubq_ <= 0 || dsub_ <= 0 || lastdsub_ <= 0) {
+                       return false;
+               }
+
+               /* Centroids are stored as dim * ksub floats (not nsubq * ksub * dsub) */
+               auto total_floats = static_cast<std::size_t>(dim_) * ksub_;
                auto ptr = reader.read_floats(total_floats);
+               if (!ptr) {
+                       return false;
+               }
                centroids_.assign(ptr, ptr + total_floats);
+               return true;
        }
 
        void add_code(const std::uint8_t *codes, float *vec, std::int32_t dim) const
        {
-               float norm = 1.0f;
+               if (centroids_.empty()) return;
                std::int32_t offset = 0;
 
                for (std::int32_t sq = 0; sq < nsubq_; sq++) {
                        auto centroid_idx = static_cast<std::size_t>(codes[sq]);
                        auto sub_dim = (sq == nsubq_ - 1) ? lastdsub_ : dsub_;
-                       auto centroid_base = static_cast<std::size_t>(sq) * ksub_ * dsub_ + centroid_idx * dsub_;
+                       auto centroid_base = get_centroid_offset(sq, centroid_idx);
 
                        for (std::int32_t d = 0; d < sub_dim; d++) {
-                               if (offset + d < dim) {
-                                       vec[offset + d] += centroids_[centroid_base + d] * norm;
+                               if (centroid_base + d < centroids_.size() && offset + d < dim) {
+                                       vec[offset + d] += centroids_[centroid_base + d];
                                }
                        }
                        offset += sub_dim;
@@ -219,16 +241,17 @@ public:
 
        auto dot_code(const std::uint8_t *codes, const float *vec, std::int32_t dim) const -> float
        {
+               if (centroids_.empty()) return 0.0f;
                float result = 0.0f;
                std::int32_t offset = 0;
 
                for (std::int32_t sq = 0; sq < nsubq_; sq++) {
                        auto centroid_idx = static_cast<std::size_t>(codes[sq]);
                        auto sub_dim = (sq == nsubq_ - 1) ? lastdsub_ : dsub_;
-                       auto centroid_base = static_cast<std::size_t>(sq) * ksub_ * dsub_ + centroid_idx * dsub_;
+                       auto centroid_base = get_centroid_offset(sq, centroid_idx);
 
                        for (std::int32_t d = 0; d < sub_dim; d++) {
-                               if (offset + d < dim) {
+                               if (centroid_base + d < centroids_.size() && offset + d < dim) {
                                        result += centroids_[centroid_base + d] * vec[offset + d];
                                }
                        }
@@ -243,6 +266,15 @@ public:
        }
 
 private:
+       /* Matches FastText's get_centroids: last sub-quantizer uses lastdsub stride */
+       auto get_centroid_offset(std::int32_t sq, std::size_t centroid_idx) const -> std::size_t
+       {
+               if (sq == nsubq_ - 1) {
+                       return static_cast<std::size_t>(sq) * ksub_ * dsub_ + centroid_idx * lastdsub_;
+               }
+               return (static_cast<std::size_t>(sq) * ksub_ + centroid_idx) * dsub_;
+       }
+
        static constexpr std::int32_t ksub_ = 256; /* number of centroids per sub-quantizer */
        std::int32_t dim_ = 0;
        std::int32_t nsubq_ = 0;
@@ -279,7 +311,7 @@ public:
 
        void add_row_to_vec(float *vec, std::int32_t row, std::int32_t dim) const override
        {
-               if (row < 0 || row >= m_) return;
+               if (!data_ || row < 0 || row >= m_) return;
                auto base = static_cast<std::size_t>(row) * n_;
                auto count = std::min(static_cast<std::int32_t>(n_), dim);
                for (std::int32_t i = 0; i < count; i++) {
@@ -289,7 +321,7 @@ public:
 
        auto dot_row(const float *vec, std::int32_t row, std::int32_t dim) const -> float override
        {
-               if (row < 0 || row >= m_) return 0.0f;
+               if (!data_ || row < 0 || row >= m_) return 0.0f;
                auto base = static_cast<std::size_t>(row) * n_;
                float result = 0.0f;
                auto count = std::min(static_cast<std::int32_t>(n_), dim);
@@ -320,13 +352,11 @@ class quant_matrix final : public matrix_base {
 public:
        quant_matrix(std::int64_t m, std::int64_t n,
                                 std::vector<std::uint8_t> &&codes,
-                                std::vector<float> &&norm_codes_float,
                                 std::vector<std::uint8_t> &&norm_codes,
                                 product_quantizer &&pq,
                                 product_quantizer &&npq)
                : m_(m), n_(n),
                  codes_(std::move(codes)),
-                 norm_codes_float_(std::move(norm_codes_float)),
                  norm_codes_(std::move(norm_codes)),
                  pq_(std::move(pq)),
                  npq_(std::move(npq))
@@ -337,16 +367,20 @@ public:
        {
                if (row < 0 || row >= m_) return;
                auto nsubq = pq_.get_nsubq();
-               auto code_ptr = codes_.data() + static_cast<std::size_t>(row) * nsubq;
-               pq_.add_code(code_ptr, vec, dim);
+               if (nsubq <= 0) return;
+               auto offset = static_cast<std::size_t>(row) * nsubq;
+               if (offset + nsubq > codes_.size()) return;
+               pq_.add_code(codes_.data() + offset, vec, dim);
        }
 
        auto dot_row(const float *vec, std::int32_t row, std::int32_t dim) const -> float override
        {
                if (row < 0 || row >= m_) return 0.0f;
                auto nsubq = pq_.get_nsubq();
-               auto code_ptr = codes_.data() + static_cast<std::size_t>(row) * nsubq;
-               return pq_.dot_code(code_ptr, vec, dim);
+               if (nsubq <= 0) return 0.0f;
+               auto offset = static_cast<std::size_t>(row) * nsubq;
+               if (offset + nsubq > codes_.size()) return 0.0f;
+               return pq_.dot_code(codes_.data() + offset, vec, dim);
        }
 
        auto rows() const -> std::int64_t override
@@ -362,7 +396,6 @@ private:
        std::int64_t m_;
        std::int64_t n_;
        std::vector<std::uint8_t> codes_;
-       std::vector<float> norm_codes_float_;
        std::vector<std::uint8_t> norm_codes_;
        product_quantizer pq_;
        product_quantizer npq_;
@@ -373,13 +406,19 @@ class dictionary {
 public:
        void load(binary_reader &reader, const model_args &args)
        {
+               auto size = reader.read_i32(); /* total entries (words + labels) */
                nwords_ = reader.read_i32();
                nlabels_ = reader.read_i32();
                ntokens_ = reader.read_i64();
 
                auto pruneidx_size = reader.read_i64();
 
-               entries_.resize(nwords_ + nlabels_);
+               if (reader.fail() || size <= 0 || size > MAX_SANE_ENTRIES ||
+                       nwords_ < 0 || nlabels_ < 0 || nwords_ + nlabels_ > size) {
+                       return;
+               }
+
+               entries_.resize(size);
                for (auto &entry: entries_) {
                        entry.word = reader.read_cstring();
                        entry.count = reader.read_i64();
@@ -427,9 +466,12 @@ public:
                return -1;
        }
 
-       auto get_entry(std::int32_t id) const -> const dict_entry &
+       auto get_entry(std::int32_t id) const -> const dict_entry *
        {
-               return entries_.at(id);
+               if (id < 0 || id >= static_cast<std::int32_t>(entries_.size())) {
+                       return nullptr;
+               }
+               return &entries_[id];
        }
 
        auto get_label(std::int32_t id) const -> std::string_view
@@ -447,10 +489,9 @@ public:
                auto ncp = static_cast<int>(positions.size() - 1); /* number of codepoints */
 
                for (int i = 0; i < ncp; i++) {
-                       /* Skip the BOW and EOW positions for single character n-grams */
                        for (int len = minn_; len <= maxn_ && i + len <= ncp; len++) {
-                               /* Skip the full wrapped word "<word>" itself */
-                               if (i == 0 && i + len == ncp) continue;
+                               /* Skip single-char n-grams at BOW/EOW positions */
+                               if (len == 1 && (i == 0 || i + 1 == ncp)) continue;
 
                                auto ngram = word.substr(positions[i], positions[i + len] - positions[i]);
                                auto h = fnv_hash(ngram) % bucket_;
@@ -512,14 +553,14 @@ public:
                auto wid = dict.find(word);
 
                if (wid >= 0) {
-                       auto &entry = dict.get_entry(wid);
-                       if (entry.type == entry_type::word) {
+                       auto *entry = dict.get_entry(wid);
+                       if (entry && entry->type == entry_type::word) {
                                if (args.maxn <= 0) {
                                        ngrams.push_back(wid);
                                }
                                else {
                                        ngrams.insert(ngrams.end(),
-                                                                 entry.subwords.begin(), entry.subwords.end());
+                                                                 entry->subwords.begin(), entry->subwords.end());
                                }
                        }
                }
@@ -537,11 +578,13 @@ public:
        {
                results.clear();
 
-               if (word_ids.empty() || !output_matrix) return;
+               if (word_ids.empty() || !input_matrix || !output_matrix) return;
 
                auto dim = args.dim;
                auto nlabels = dict.get_nlabels();
 
+               if (dim <= 0 || dim > MAX_SANE_DIM || nlabels <= 0) return;
+
                /* Compute hidden layer: average of input rows */
                std::vector<float> hidden(dim, 0.0f);
                for (auto id: word_ids) {
@@ -604,6 +647,10 @@ public:
        void get_word_vector(std::vector<float> &vec, std::string_view word) const
        {
                auto dim = args.dim;
+               if (dim <= 0 || dim > MAX_SANE_DIM || !input_matrix) {
+                       vec.clear();
+                       return;
+               }
                vec.assign(dim, 0.0f);
 
                std::vector<std::int32_t> ngrams;
@@ -629,9 +676,25 @@ static auto load_dense_matrix(binary_reader &reader, const unsigned char *mmap_b
        auto m = reader.read_i64();
        auto n = reader.read_i64();
 
-       auto float_count = static_cast<std::size_t>(m) * n;
+       if (reader.fail() || m <= 0 || m > MAX_SANE_MATRIX_ROWS ||
+               n <= 0 || n > MAX_SANE_DIM) {
+               return nullptr;
+       }
+
+       /* Check for overflow: m * n must fit in size_t */
+       auto um = static_cast<std::size_t>(m);
+       auto un = static_cast<std::size_t>(n);
+       if (um > SIZE_MAX / un) {
+               return nullptr;
+       }
+
+       auto float_count = um * un;
        auto data_ptr = reader.read_floats(float_count);
 
+       if (!data_ptr) {
+               return nullptr;
+       }
+
        /* Check if this pointer is inside the mmap region (zero-copy) or need to copy */
        if (mmap_base != nullptr) {
                return std::make_unique<dense_matrix>(data_ptr, m, n);
@@ -643,48 +706,57 @@ static auto load_dense_matrix(binary_reader &reader, const unsigned char *mmap_b
 }
 
 /* --- Load a quantized matrix from binary data --- */
+/* FastText QMatrix binary format:
+ *   qnorm (bool), m (int64), n (int64), codesize (int32),
+ *   codes[codesize] (uint8), PQ::load(),
+ *   [if qnorm: norm_codes[m] (uint8), NPQ::load()]
+ */
 static auto load_quant_matrix(binary_reader &reader)
        -> std::unique_ptr<quant_matrix>
 {
        auto qnorm = reader.read_bool();
        auto m = reader.read_i64();
        auto n = reader.read_i64();
+       auto codesize = reader.read_i32();
 
-       /* Read codes_count = m * pq.nsubq (but we read PQ first to get nsubq) */
+       if (reader.fail() || m <= 0 || m > MAX_SANE_MATRIX_ROWS ||
+               n <= 0 || n > MAX_SANE_DIM ||
+               codesize <= 0 || codesize > MAX_SANE_CODESIZE) {
+               return nullptr;
+       }
+
+       /* Read codes BEFORE PQ */
+       std::vector<std::uint8_t> codes(codesize);
+       for (std::int32_t i = 0; i < codesize; i++) {
+               codes[i] = reader.read_u8();
+       }
 
        /* Read PQ */
        product_quantizer pq;
-       pq.load(reader);
-
-       auto nsubq = pq.get_nsubq();
-       auto codes_size = static_cast<std::size_t>(m) * nsubq;
-       std::vector<std::uint8_t> codes(codes_size);
-       for (std::size_t i = 0; i < codes_size; i++) {
-               codes[i] = reader.read_u8();
+       if (!pq.load(reader)) {
+               return nullptr;
        }
 
-       std::vector<float> norm_codes_float;
        std::vector<std::uint8_t> norm_codes;
        product_quantizer npq;
 
        if (qnorm) {
-               npq.load(reader);
-               auto norm_nsubq = npq.get_nsubq();
-               auto norm_codes_size = static_cast<std::size_t>(m) * norm_nsubq;
-               norm_codes.resize(norm_codes_size);
-               for (std::size_t i = 0; i < norm_codes_size; i++) {
+               /* Read norm codes (one per row) then norm PQ */
+               norm_codes.resize(m);
+               for (std::int64_t i = 0; i < m; i++) {
                        norm_codes[i] = reader.read_u8();
                }
+               if (!npq.load(reader)) {
+                       return nullptr;
+               }
        }
-       else {
-               /* Read float norms */
-               norm_codes_float.resize(m);
-               auto ptr = reader.read_floats(m);
-               std::copy(ptr, ptr + m, norm_codes_float.begin());
+
+       if (reader.fail()) {
+               return nullptr;
        }
 
        return std::make_unique<quant_matrix>(m, n,
-                                                                                 std::move(codes), std::move(norm_codes_float),
+                                                                                 std::move(codes),
                                                                                  std::move(norm_codes),
                                                                                  std::move(pq), std::move(npq));
 }
@@ -761,6 +833,17 @@ auto fasttext_model::load(const char *path) -> tl::expected<fasttext_model, rspa
                                EINVAL));
        }
 
+       /* Validate model args sanity */
+       if (args.dim <= 0 || args.dim > MAX_SANE_DIM ||
+               args.bucket < 0 || args.bucket > MAX_SANE_BUCKET ||
+               args.minn < 0 || args.maxn < 0 || args.maxn > 100) {
+               return tl::make_unexpected(
+                       rspamd::util::error(
+                               fmt::format("invalid fasttext model parameters in '{}': dim={} bucket={} minn={} maxn={}",
+                                                       path, args.dim, args.bucket, args.minn, args.maxn),
+                               EINVAL));
+       }
+
        /* Read dictionary */
        impl->dict.load(reader, args);
 
@@ -782,32 +865,27 @@ auto fasttext_model::load(const char *path) -> tl::expected<fasttext_model, rspa
                impl->input_matrix = load_dense_matrix(reader, base);
        }
 
-       if (reader.fail()) {
+       if (!impl->input_matrix || reader.fail()) {
                return tl::make_unexpected(
                        rspamd::util::error(
-                               fmt::format("truncated fasttext input matrix in '{}'", path),
+                               fmt::format("failed to load fasttext input matrix in '{}'", path),
                                EINVAL));
        }
 
-       /* Read output matrix - check if quantized */
-       if (!quant_input) {
-               /* Dense output */
-               impl->output_matrix = load_dense_matrix(reader, nullptr);
+       /* Read output matrix — FastText always writes a qout bool here */
+       auto quant_output = reader.read_bool();
+
+       if (quant_input && quant_output) {
+               impl->output_matrix = load_quant_matrix(reader);
        }
        else {
-               auto quant_output = reader.read_bool();
-               if (quant_output) {
-                       impl->output_matrix = load_quant_matrix(reader);
-               }
-               else {
-                       impl->output_matrix = load_dense_matrix(reader, nullptr);
-               }
+               impl->output_matrix = load_dense_matrix(reader, nullptr);
        }
 
-       if (reader.fail()) {
+       if (!impl->output_matrix || reader.fail()) {
                return tl::make_unexpected(
                        rspamd::util::error(
-                               fmt::format("truncated fasttext output matrix in '{}'", path),
+                               fmt::format("failed to load fasttext output matrix in '{}'", path),
                                EINVAL));
        }
 
index 7df881b23a5fae2cb179da1a09992c4af0e6a64c..6a92e9871d903c79087c352d84db610006701123 100644 (file)
@@ -170,13 +170,13 @@ lua_fasttext_model_get_word_vector(lua_State *L)
                return 1;
        }
 
-       auto dim = model->model->get_dimension();
        std::vector<float> vec;
 
        model->model->get_word_vector(vec, std::string_view{word});
 
-       lua_createtable(L, dim, 0);
-       for (std::int32_t i = 0; i < dim; i++) {
+       auto vec_size = static_cast<std::int32_t>(vec.size());
+       lua_createtable(L, vec_size, 0);
+       for (std::int32_t i = 0; i < vec_size; i++) {
                lua_pushnumber(L, static_cast<double>(vec[i]));
                lua_rawseti(L, -2, i + 1);
        }
@@ -205,6 +205,11 @@ lua_fasttext_model_get_sentence_vector(lua_State *L)
        luaL_argcheck(L, lua_istable(L, 2), 2, "'table' of words expected");
 
        auto dim = model->model->get_dimension();
+       if (dim <= 0 || dim > 4096) {
+               lua_pushnil(L);
+               return 1;
+       }
+
        std::vector<float> sentence_vec(dim, 0.0f);
        std::vector<float> word_vec;
        int count = 0;
@@ -219,7 +224,8 @@ lua_fasttext_model_get_sentence_vector(lua_State *L)
                        const char *w = lua_tolstring(L, -1, &len);
                        if (len > 0) {
                                model->model->get_word_vector(word_vec, std::string_view{w, len});
-                               for (std::int32_t d = 0; d < dim; d++) {
+                               auto wv_size = std::min(dim, static_cast<std::int32_t>(word_vec.size()));
+                               for (std::int32_t d = 0; d < wv_size; d++) {
                                        sentence_vec[d] += word_vec[d];
                                }
                                count++;