namespace {
/* --- Binary reader: a cursor over memory-mapped data --- */
+/* Uses a fail-bit pattern instead of exceptions: once any read overflows,
+ * the reader enters a failed state and all subsequent reads return zeroes.
+ * Callers check fail() after a sequence of reads. */
class binary_reader {
public:
binary_reader(const unsigned char *data, std::size_t size)
- : data_(data), size_(size), pos_(0)
+ : data_(data), size_(size), pos_(0), failed_(false)
{
}
}
auto remaining() const -> std::size_t
{
- return size_ - pos_;
+ return failed_ ? 0 : size_ - pos_;
}
-
- void seek(std::size_t pos)
+ auto fail() const -> bool
{
- if (pos > size_) {
- throw std::out_of_range("binary_reader::seek past end");
- }
- pos_ = pos;
+ return failed_;
}
void skip(std::size_t n)
{
- check(n);
+ if (!ensure(n)) return;
pos_ += n;
}
auto read_i32() -> std::int32_t
{
- check(4);
+ if (!ensure(4)) return 0;
std::int32_t v;
std::memcpy(&v, data_ + pos_, 4);
pos_ += 4;
auto read_i64() -> std::int64_t
{
- check(8);
+ if (!ensure(8)) return 0;
std::int64_t v;
std::memcpy(&v, data_ + pos_, 8);
pos_ += 8;
auto read_f64() -> double
{
- check(8);
+ if (!ensure(8)) return 0.0;
double v;
std::memcpy(&v, data_ + pos_, 8);
pos_ += 8;
auto read_u8() -> std::uint8_t
{
- check(1);
+ if (!ensure(1)) return 0;
auto v = data_[pos_];
pos_ += 1;
return v;
auto read_floats(std::size_t count) -> const float *
{
auto bytes = count * sizeof(float);
- check(bytes);
+ if (!ensure(bytes)) return nullptr;
auto ptr = reinterpret_cast<const float *>(data_ + pos_);
pos_ += bytes;
return ptr;
}
- auto current_ptr() const -> const unsigned char *
- {
- return data_ + pos_;
- }
-
private:
- void check(std::size_t n) const
+ auto ensure(std::size_t n) -> bool
{
- if (pos_ + n > size_) {
- throw std::out_of_range(
- fmt::format("binary_reader: need {} bytes at offset {}, but only {} available",
- n, pos_, size_ - pos_));
+ if (failed_ || pos_ + n > size_) {
+ failed_ = true;
+ return false;
}
+ return true;
}
const unsigned char *data_;
std::size_t size_;
std::size_t pos_;
+ bool failed_;
};
/* --- FNV-1a hash matching FastText's implementation --- */
binary_reader reader(base, file_size);
- try {
- /* Read and validate magic */
- auto magic = reader.read_i32();
- if (magic != FASTTEXT_FILEFORMAT_MAGIC) {
- return tl::make_unexpected(
- rspamd::util::error(
- fmt::format("invalid fasttext magic: {} (expected {})", magic, FASTTEXT_FILEFORMAT_MAGIC),
- EINVAL));
- }
+ /* Read and validate magic */
+ auto magic = reader.read_i32();
+ if (reader.fail() || magic != FASTTEXT_FILEFORMAT_MAGIC) {
+ return tl::make_unexpected(
+ rspamd::util::error(
+ fmt::format("invalid fasttext magic: {} (expected {})", magic, FASTTEXT_FILEFORMAT_MAGIC),
+ EINVAL));
+ }
- /* Read and validate version */
- auto version = reader.read_i32();
- if (version > FASTTEXT_VERSION) {
- return tl::make_unexpected(
- rspamd::util::error(
- fmt::format("unsupported fasttext version: {} (max {})", version, FASTTEXT_VERSION),
- EINVAL));
- }
+ /* Read and validate version */
+ auto version = reader.read_i32();
+ if (reader.fail() || version > FASTTEXT_VERSION) {
+ return tl::make_unexpected(
+ rspamd::util::error(
+ fmt::format("unsupported fasttext version: {} (max {})", version, FASTTEXT_VERSION),
+ EINVAL));
+ }
- auto impl = std::make_unique<fasttext_model_impl>();
-
- /* Read model args (52 bytes of packed data) */
- auto &args = impl->args;
- args.dim = reader.read_i32();
- args.ws = reader.read_i32();
- args.epoch = reader.read_i32();
- args.minCount = reader.read_i32();
- args.neg = reader.read_i32();
- args.wordNgrams = reader.read_i32();
- args.loss = reader.read_i32();
- args.model = static_cast<model_name>(reader.read_i32());
- args.bucket = reader.read_i32();
- args.minn = reader.read_i32();
- args.maxn = reader.read_i32();
- args.lrUpdateRate = reader.read_i32();
- args.t = reader.read_f64();
-
- /* Read dictionary */
- impl->dict.load(reader, args);
-
- /* Determine if input matrix is quantized */
- auto quant_input = reader.read_bool();
-
- if (quant_input) {
- impl->input_matrix = load_quant_matrix(reader);
- }
- else {
- /* Dense input matrix - pointer into mmap region (zero-copy) */
- impl->input_matrix = load_dense_matrix(reader, base);
- }
+ auto impl = std::make_unique<fasttext_model_impl>();
+
+ /* Read model args (52 bytes of packed data) */
+ auto &args = impl->args;
+ args.dim = reader.read_i32();
+ args.ws = reader.read_i32();
+ args.epoch = reader.read_i32();
+ args.minCount = reader.read_i32();
+ args.neg = reader.read_i32();
+ args.wordNgrams = reader.read_i32();
+ args.loss = reader.read_i32();
+ args.model = static_cast<model_name>(reader.read_i32());
+ args.bucket = reader.read_i32();
+ args.minn = reader.read_i32();
+ args.maxn = reader.read_i32();
+ args.lrUpdateRate = reader.read_i32();
+ args.t = reader.read_f64();
+
+ if (reader.fail()) {
+ return tl::make_unexpected(
+ rspamd::util::error(
+ fmt::format("truncated fasttext model header in '{}'", path),
+ EINVAL));
+ }
- /* Read output matrix - check if quantized */
- if (!quant_input) {
- /* Dense output */
- impl->output_matrix = load_dense_matrix(reader, nullptr);
+ /* Read dictionary */
+ impl->dict.load(reader, args);
+
+ if (reader.fail()) {
+ return tl::make_unexpected(
+ rspamd::util::error(
+ fmt::format("truncated fasttext dictionary in '{}'", path),
+ EINVAL));
+ }
+
+ /* Determine if input matrix is quantized */
+ auto quant_input = reader.read_bool();
+
+ if (quant_input) {
+ impl->input_matrix = load_quant_matrix(reader);
+ }
+ else {
+ /* Dense input matrix - pointer into mmap region (zero-copy) */
+ impl->input_matrix = load_dense_matrix(reader, base);
+ }
+
+ if (reader.fail()) {
+ return tl::make_unexpected(
+ rspamd::util::error(
+ fmt::format("truncated fasttext input matrix in '{}'", path),
+ EINVAL));
+ }
+
+ /* Read output matrix - check if quantized */
+ if (!quant_input) {
+ /* Dense output */
+ impl->output_matrix = load_dense_matrix(reader, nullptr);
+ }
+ else {
+ auto quant_output = reader.read_bool();
+ if (quant_output) {
+ impl->output_matrix = load_quant_matrix(reader);
}
else {
- auto quant_output = reader.read_bool();
- if (quant_output) {
- impl->output_matrix = load_quant_matrix(reader);
- }
- else {
- impl->output_matrix = load_dense_matrix(reader, nullptr);
- }
+ impl->output_matrix = load_dense_matrix(reader, nullptr);
}
+ }
- /* Store the mmap to keep it alive */
- impl->mmap_file.emplace(std::move(*mmap_result));
-
- return fasttext_model(std::move(impl));
- } catch (const std::exception &e) {
+ if (reader.fail()) {
return tl::make_unexpected(
rspamd::util::error(
- fmt::format("failed to parse fasttext model '{}': {}", path, e.what()),
+ fmt::format("truncated fasttext output matrix in '{}'", path),
EINVAL));
}
+
+ /* Store the mmap to keep it alive */
+ impl->mmap_file.emplace(std::move(*mmap_result));
+
+ return fasttext_model(std::move(impl));
}
void fasttext_model::word2vec(std::string_view word, std::vector<std::int32_t> &ngrams) const