namespace {
+/* Sanity limits for untrusted values read from model files */
+constexpr std::int32_t MAX_SANE_DIM = 4096;
+constexpr std::int32_t MAX_SANE_ENTRIES = 50'000'000;
+constexpr std::int64_t MAX_SANE_MATRIX_ROWS = 500'000'000;
+constexpr std::int32_t MAX_SANE_CODESIZE = 500'000'000;
+constexpr std::int32_t MAX_SANE_BUCKET = 100'000'000;
+constexpr std::size_t MAX_SANE_STRING = 1024;
+
/* --- Binary reader: a cursor over memory-mapped data --- */
/* Uses a fail-bit pattern instead of exceptions: once any read overflows,
* the reader enters a failed state and all subsequent reads return zeroes.
auto ch = data_[pos_++];
if (ch == 0) break;
result.push_back(static_cast<char>(ch));
+ if (result.size() > MAX_SANE_STRING) {
+ failed_ = true;
+ return {};
+ }
}
return result;
}
auto read_floats(std::size_t count) -> const float *
{
+ if (count > SIZE_MAX / sizeof(float)) {
+ failed_ = true;
+ return nullptr;
+ }
auto bytes = count * sizeof(float);
if (!ensure(bytes)) return nullptr;
auto ptr = reinterpret_cast<const float *>(data_ + pos_);
/* --- Product Quantizer --- */
class product_quantizer {
public:
- void load(binary_reader &reader)
+ auto load(binary_reader &reader) -> bool
{
dim_ = reader.read_i32();
nsubq_ = reader.read_i32();
dsub_ = reader.read_i32();
lastdsub_ = reader.read_i32();
- auto centroid_count = static_cast<std::size_t>(nsubq_) * ksub_;
- auto total_floats = centroid_count * dsub_;
- /* Actually, centroids are stored as nsubq * ksub * dsub floats,
- * but the last subquantizer may use lastdsub. FastText stores
- * all centroids with dsub stride. */
+ if (reader.fail() || dim_ <= 0 || dim_ > MAX_SANE_DIM ||
+ nsubq_ <= 0 || dsub_ <= 0 || lastdsub_ <= 0) {
+ return false;
+ }
+
+ /* Centroids are stored as dim * ksub floats (not nsubq * ksub * dsub) */
+ auto total_floats = static_cast<std::size_t>(dim_) * ksub_;
auto ptr = reader.read_floats(total_floats);
+ if (!ptr) {
+ return false;
+ }
centroids_.assign(ptr, ptr + total_floats);
+ return true;
}
void add_code(const std::uint8_t *codes, float *vec, std::int32_t dim) const
{
- float norm = 1.0f;
+ if (centroids_.empty()) return;
std::int32_t offset = 0;
for (std::int32_t sq = 0; sq < nsubq_; sq++) {
auto centroid_idx = static_cast<std::size_t>(codes[sq]);
auto sub_dim = (sq == nsubq_ - 1) ? lastdsub_ : dsub_;
- auto centroid_base = static_cast<std::size_t>(sq) * ksub_ * dsub_ + centroid_idx * dsub_;
+ auto centroid_base = get_centroid_offset(sq, centroid_idx);
for (std::int32_t d = 0; d < sub_dim; d++) {
- if (offset + d < dim) {
- vec[offset + d] += centroids_[centroid_base + d] * norm;
+ if (centroid_base + d < centroids_.size() && offset + d < dim) {
+ vec[offset + d] += centroids_[centroid_base + d];
}
}
offset += sub_dim;
auto dot_code(const std::uint8_t *codes, const float *vec, std::int32_t dim) const -> float
{
+ if (centroids_.empty()) return 0.0f;
float result = 0.0f;
std::int32_t offset = 0;
for (std::int32_t sq = 0; sq < nsubq_; sq++) {
auto centroid_idx = static_cast<std::size_t>(codes[sq]);
auto sub_dim = (sq == nsubq_ - 1) ? lastdsub_ : dsub_;
- auto centroid_base = static_cast<std::size_t>(sq) * ksub_ * dsub_ + centroid_idx * dsub_;
+ auto centroid_base = get_centroid_offset(sq, centroid_idx);
for (std::int32_t d = 0; d < sub_dim; d++) {
- if (offset + d < dim) {
+ if (centroid_base + d < centroids_.size() && offset + d < dim) {
result += centroids_[centroid_base + d] * vec[offset + d];
}
}
}
private:
+ /* Matches FastText's get_centroids: last sub-quantizer uses lastdsub stride */
+ auto get_centroid_offset(std::int32_t sq, std::size_t centroid_idx) const -> std::size_t
+ {
+ if (sq == nsubq_ - 1) {
+ return static_cast<std::size_t>(sq) * ksub_ * dsub_ + centroid_idx * lastdsub_;
+ }
+ return (static_cast<std::size_t>(sq) * ksub_ + centroid_idx) * dsub_;
+ }
+
static constexpr std::int32_t ksub_ = 256; /* number of centroids per sub-quantizer */
std::int32_t dim_ = 0;
std::int32_t nsubq_ = 0;
void add_row_to_vec(float *vec, std::int32_t row, std::int32_t dim) const override
{
- if (row < 0 || row >= m_) return;
+ if (!data_ || row < 0 || row >= m_) return;
auto base = static_cast<std::size_t>(row) * n_;
auto count = std::min(static_cast<std::int32_t>(n_), dim);
for (std::int32_t i = 0; i < count; i++) {
auto dot_row(const float *vec, std::int32_t row, std::int32_t dim) const -> float override
{
- if (row < 0 || row >= m_) return 0.0f;
+ if (!data_ || row < 0 || row >= m_) return 0.0f;
auto base = static_cast<std::size_t>(row) * n_;
float result = 0.0f;
auto count = std::min(static_cast<std::int32_t>(n_), dim);
public:
quant_matrix(std::int64_t m, std::int64_t n,
std::vector<std::uint8_t> &&codes,
- std::vector<float> &&norm_codes_float,
std::vector<std::uint8_t> &&norm_codes,
product_quantizer &&pq,
product_quantizer &&npq)
: m_(m), n_(n),
codes_(std::move(codes)),
- norm_codes_float_(std::move(norm_codes_float)),
norm_codes_(std::move(norm_codes)),
pq_(std::move(pq)),
npq_(std::move(npq))
{
if (row < 0 || row >= m_) return;
auto nsubq = pq_.get_nsubq();
- auto code_ptr = codes_.data() + static_cast<std::size_t>(row) * nsubq;
- pq_.add_code(code_ptr, vec, dim);
+ if (nsubq <= 0) return;
+ auto offset = static_cast<std::size_t>(row) * nsubq;
+ if (offset + nsubq > codes_.size()) return;
+ pq_.add_code(codes_.data() + offset, vec, dim);
}
auto dot_row(const float *vec, std::int32_t row, std::int32_t dim) const -> float override
{
if (row < 0 || row >= m_) return 0.0f;
auto nsubq = pq_.get_nsubq();
- auto code_ptr = codes_.data() + static_cast<std::size_t>(row) * nsubq;
- return pq_.dot_code(code_ptr, vec, dim);
+ if (nsubq <= 0) return 0.0f;
+ auto offset = static_cast<std::size_t>(row) * nsubq;
+ if (offset + nsubq > codes_.size()) return 0.0f;
+ return pq_.dot_code(codes_.data() + offset, vec, dim);
}
auto rows() const -> std::int64_t override
std::int64_t m_;
std::int64_t n_;
std::vector<std::uint8_t> codes_;
- std::vector<float> norm_codes_float_;
std::vector<std::uint8_t> norm_codes_;
product_quantizer pq_;
product_quantizer npq_;
public:
void load(binary_reader &reader, const model_args &args)
{
+ auto size = reader.read_i32(); /* total entries (words + labels) */
nwords_ = reader.read_i32();
nlabels_ = reader.read_i32();
ntokens_ = reader.read_i64();
auto pruneidx_size = reader.read_i64();
- entries_.resize(nwords_ + nlabels_);
+ if (reader.fail() || size <= 0 || size > MAX_SANE_ENTRIES ||
+ nwords_ < 0 || nlabels_ < 0 || nwords_ + nlabels_ > size) {
+ return;
+ }
+
+ entries_.resize(size);
for (auto &entry: entries_) {
entry.word = reader.read_cstring();
entry.count = reader.read_i64();
return -1;
}
- auto get_entry(std::int32_t id) const -> const dict_entry &
+ auto get_entry(std::int32_t id) const -> const dict_entry *
{
- return entries_.at(id);
+ if (id < 0 || id >= static_cast<std::int32_t>(entries_.size())) {
+ return nullptr;
+ }
+ return &entries_[id];
}
auto get_label(std::int32_t id) const -> std::string_view
auto ncp = static_cast<int>(positions.size() - 1); /* number of codepoints */
for (int i = 0; i < ncp; i++) {
- /* Skip the BOW and EOW positions for single character n-grams */
for (int len = minn_; len <= maxn_ && i + len <= ncp; len++) {
- /* Skip the full wrapped word "<word>" itself */
- if (i == 0 && i + len == ncp) continue;
+ /* Skip single-char n-grams at BOW/EOW positions */
+ if (len == 1 && (i == 0 || i + 1 == ncp)) continue;
auto ngram = word.substr(positions[i], positions[i + len] - positions[i]);
auto h = fnv_hash(ngram) % bucket_;
auto wid = dict.find(word);
if (wid >= 0) {
- auto &entry = dict.get_entry(wid);
- if (entry.type == entry_type::word) {
+ auto *entry = dict.get_entry(wid);
+ if (entry && entry->type == entry_type::word) {
if (args.maxn <= 0) {
ngrams.push_back(wid);
}
else {
ngrams.insert(ngrams.end(),
- entry.subwords.begin(), entry.subwords.end());
+ entry->subwords.begin(), entry->subwords.end());
}
}
}
{
results.clear();
- if (word_ids.empty() || !output_matrix) return;
+ if (word_ids.empty() || !input_matrix || !output_matrix) return;
auto dim = args.dim;
auto nlabels = dict.get_nlabels();
+ if (dim <= 0 || dim > MAX_SANE_DIM || nlabels <= 0) return;
+
/* Compute hidden layer: average of input rows */
std::vector<float> hidden(dim, 0.0f);
for (auto id: word_ids) {
void get_word_vector(std::vector<float> &vec, std::string_view word) const
{
auto dim = args.dim;
+ if (dim <= 0 || dim > MAX_SANE_DIM || !input_matrix) {
+ vec.clear();
+ return;
+ }
vec.assign(dim, 0.0f);
std::vector<std::int32_t> ngrams;
auto m = reader.read_i64();
auto n = reader.read_i64();
- auto float_count = static_cast<std::size_t>(m) * n;
+ if (reader.fail() || m <= 0 || m > MAX_SANE_MATRIX_ROWS ||
+ n <= 0 || n > MAX_SANE_DIM) {
+ return nullptr;
+ }
+
+ /* Check for overflow: m * n must fit in size_t */
+ auto um = static_cast<std::size_t>(m);
+ auto un = static_cast<std::size_t>(n);
+ if (um > SIZE_MAX / un) {
+ return nullptr;
+ }
+
+ auto float_count = um * un;
auto data_ptr = reader.read_floats(float_count);
+ if (!data_ptr) {
+ return nullptr;
+ }
+
/* Check if this pointer is inside the mmap region (zero-copy) or need to copy */
if (mmap_base != nullptr) {
return std::make_unique<dense_matrix>(data_ptr, m, n);
}
/* --- Load a quantized matrix from binary data --- */
+/* FastText QMatrix binary format:
+ * qnorm (bool), m (int64), n (int64), codesize (int32),
+ * codes[codesize] (uint8), PQ::load(),
+ * [if qnorm: norm_codes[m] (uint8), NPQ::load()]
+ */
static auto load_quant_matrix(binary_reader &reader)
-> std::unique_ptr<quant_matrix>
{
auto qnorm = reader.read_bool();
auto m = reader.read_i64();
auto n = reader.read_i64();
+ auto codesize = reader.read_i32();
- /* Read codes_count = m * pq.nsubq (but we read PQ first to get nsubq) */
+ if (reader.fail() || m <= 0 || m > MAX_SANE_MATRIX_ROWS ||
+ n <= 0 || n > MAX_SANE_DIM ||
+ codesize <= 0 || codesize > MAX_SANE_CODESIZE) {
+ return nullptr;
+ }
+
+ /* Read codes BEFORE PQ */
+ std::vector<std::uint8_t> codes(codesize);
+ for (std::int32_t i = 0; i < codesize; i++) {
+ codes[i] = reader.read_u8();
+ }
/* Read PQ */
product_quantizer pq;
- pq.load(reader);
-
- auto nsubq = pq.get_nsubq();
- auto codes_size = static_cast<std::size_t>(m) * nsubq;
- std::vector<std::uint8_t> codes(codes_size);
- for (std::size_t i = 0; i < codes_size; i++) {
- codes[i] = reader.read_u8();
+ if (!pq.load(reader)) {
+ return nullptr;
}
- std::vector<float> norm_codes_float;
std::vector<std::uint8_t> norm_codes;
product_quantizer npq;
if (qnorm) {
- npq.load(reader);
- auto norm_nsubq = npq.get_nsubq();
- auto norm_codes_size = static_cast<std::size_t>(m) * norm_nsubq;
- norm_codes.resize(norm_codes_size);
- for (std::size_t i = 0; i < norm_codes_size; i++) {
+ /* Read norm codes (one per row) then norm PQ */
+ norm_codes.resize(m);
+ for (std::int64_t i = 0; i < m; i++) {
norm_codes[i] = reader.read_u8();
}
+ if (!npq.load(reader)) {
+ return nullptr;
+ }
}
- else {
- /* Read float norms */
- norm_codes_float.resize(m);
- auto ptr = reader.read_floats(m);
- std::copy(ptr, ptr + m, norm_codes_float.begin());
+
+ if (reader.fail()) {
+ return nullptr;
}
return std::make_unique<quant_matrix>(m, n,
- std::move(codes), std::move(norm_codes_float),
+ std::move(codes),
std::move(norm_codes),
std::move(pq), std::move(npq));
}
EINVAL));
}
+ /* Validate model args sanity */
+ if (args.dim <= 0 || args.dim > MAX_SANE_DIM ||
+ args.bucket < 0 || args.bucket > MAX_SANE_BUCKET ||
+ args.minn < 0 || args.maxn < 0 || args.maxn > 100) {
+ return tl::make_unexpected(
+ rspamd::util::error(
+ fmt::format("invalid fasttext model parameters in '{}': dim={} bucket={} minn={} maxn={}",
+ path, args.dim, args.bucket, args.minn, args.maxn),
+ EINVAL));
+ }
+
/* Read dictionary */
impl->dict.load(reader, args);
impl->input_matrix = load_dense_matrix(reader, base);
}
- if (reader.fail()) {
+ if (!impl->input_matrix || reader.fail()) {
return tl::make_unexpected(
rspamd::util::error(
- fmt::format("truncated fasttext input matrix in '{}'", path),
+ fmt::format("failed to load fasttext input matrix in '{}'", path),
EINVAL));
}
- /* Read output matrix - check if quantized */
- if (!quant_input) {
- /* Dense output */
- impl->output_matrix = load_dense_matrix(reader, nullptr);
+ /* Read output matrix — FastText always writes a qout bool here */
+ auto quant_output = reader.read_bool();
+
+ if (quant_input && quant_output) {
+ impl->output_matrix = load_quant_matrix(reader);
}
else {
- auto quant_output = reader.read_bool();
- if (quant_output) {
- impl->output_matrix = load_quant_matrix(reader);
- }
- else {
- impl->output_matrix = load_dense_matrix(reader, nullptr);
- }
+ impl->output_matrix = load_dense_matrix(reader, nullptr);
}
- if (reader.fail()) {
+ if (!impl->output_matrix || reader.fail()) {
return tl::make_unexpected(
rspamd::util::error(
- fmt::format("truncated fasttext output matrix in '{}'", path),
+ fmt::format("failed to load fasttext output matrix in '{}'", path),
EINVAL));
}