git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Minor] Fasttext shim: addressing review comments
author Vsevolod Stakhov <vsevolod@rspamd.com>
Sat, 21 Feb 2026 16:46:19 +0000 (16:46 +0000)
committer Vsevolod Stakhov <vsevolod@rspamd.com>
Sat, 21 Feb 2026 16:46:19 +0000 (16:46 +0000)
- Use ICU U8_NEXT for UTF-8 iteration instead of handcrafted code
- Replace exception-based error handling with fail-bit pattern in
  binary_reader, propagating errors via tl::expected
- Replace std::sort with std::reverse after min-heap extraction

src/libserver/fasttext/fasttext_shim.cxx

index 2f6c0e800cadac686d6fd4d79f64e117c7bf17cf..6a51a6f57c6eb99a139de4c7ae040d0d14f183b4 100644 (file)
@@ -34,10 +34,13 @@ namespace rspamd::fasttext {
 namespace {
 
 /* --- Binary reader: a cursor over memory-mapped data --- */
+/* Uses a fail-bit pattern instead of exceptions: once any read would run past
+ * the end of the data, the reader latches a failed state and later reads return
+ * zero (nullptr for read_floats). Callers check fail() after a run of reads. */
 class binary_reader {
 public:
        binary_reader(const unsigned char *data, std::size_t size)
-               : data_(data), size_(size), pos_(0)
+               : data_(data), size_(size), pos_(0), failed_(false)
        {
        }
 
@@ -47,26 +50,22 @@ public:
        }
        auto remaining() const -> std::size_t
        {
-               return size_ - pos_;
+               return failed_ ? 0 : size_ - pos_;
        }
-
-       void seek(std::size_t pos)
+       auto fail() const -> bool
        {
-               if (pos > size_) {
-                       throw std::out_of_range("binary_reader::seek past end");
-               }
-               pos_ = pos;
+               return failed_;
        }
 
        void skip(std::size_t n)
        {
-               check(n);
+               if (!ensure(n)) return;
                pos_ += n;
        }
 
        auto read_i32() -> std::int32_t
        {
-               check(4);
+               if (!ensure(4)) return 0;
                std::int32_t v;
                std::memcpy(&v, data_ + pos_, 4);
                pos_ += 4;
@@ -75,7 +74,7 @@ public:
 
        auto read_i64() -> std::int64_t
        {
-               check(8);
+               if (!ensure(8)) return 0;
                std::int64_t v;
                std::memcpy(&v, data_ + pos_, 8);
                pos_ += 8;
@@ -84,7 +83,7 @@ public:
 
        auto read_f64() -> double
        {
-               check(8);
+               if (!ensure(8)) return 0.0;
                double v;
                std::memcpy(&v, data_ + pos_, 8);
                pos_ += 8;
@@ -93,7 +92,7 @@ public:
 
        auto read_u8() -> std::uint8_t
        {
-               check(1);
+               if (!ensure(1)) return 0;
                auto v = data_[pos_];
                pos_ += 1;
                return v;
@@ -118,30 +117,26 @@ public:
        auto read_floats(std::size_t count) -> const float *
        {
                auto bytes = count * sizeof(float);
-               check(bytes);
+               if (!ensure(bytes)) return nullptr;
                auto ptr = reinterpret_cast<const float *>(data_ + pos_);
                pos_ += bytes;
                return ptr;
        }
 
-       auto current_ptr() const -> const unsigned char *
-       {
-               return data_ + pos_;
-       }
-
 private:
-       void check(std::size_t n) const
+       auto ensure(std::size_t n) -> bool
        {
-               if (pos_ + n > size_) {
-                       throw std::out_of_range(
-                               fmt::format("binary_reader: need {} bytes at offset {}, but only {} available",
-                                                       n, pos_, size_ - pos_));
+               if (failed_ || pos_ + n > size_) {
+                       failed_ = true;
+                       return false;
                }
+               return true;
        }
 
        const unsigned char *data_;
        std::size_t size_;
        std::size_t pos_;
+       bool failed_;
 };
 
 /* --- FNV-1a hash matching FastText's implementation --- */
@@ -723,82 +718,103 @@ auto fasttext_model::load(const char *path) -> tl::expected<fasttext_model, rspa
 
        binary_reader reader(base, file_size);
 
-       try {
-               /* Read and validate magic */
-               auto magic = reader.read_i32();
-               if (magic != FASTTEXT_FILEFORMAT_MAGIC) {
-                       return tl::make_unexpected(
-                               rspamd::util::error(
-                                       fmt::format("invalid fasttext magic: {} (expected {})", magic, FASTTEXT_FILEFORMAT_MAGIC),
-                                       EINVAL));
-               }
+       /* Read and validate magic */
+       auto magic = reader.read_i32();
+       if (reader.fail() || magic != FASTTEXT_FILEFORMAT_MAGIC) {
+               return tl::make_unexpected(
+                       rspamd::util::error(
+                               fmt::format("invalid fasttext magic: {} (expected {})", magic, FASTTEXT_FILEFORMAT_MAGIC),
+                               EINVAL));
+       }
 
-               /* Read and validate version */
-               auto version = reader.read_i32();
-               if (version > FASTTEXT_VERSION) {
-                       return tl::make_unexpected(
-                               rspamd::util::error(
-                                       fmt::format("unsupported fasttext version: {} (max {})", version, FASTTEXT_VERSION),
-                                       EINVAL));
-               }
+       /* Read and validate version */
+       auto version = reader.read_i32();
+       if (reader.fail() || version > FASTTEXT_VERSION) {
+               return tl::make_unexpected(
+                       rspamd::util::error(
+                               fmt::format("unsupported fasttext version: {} (max {})", version, FASTTEXT_VERSION),
+                               EINVAL));
+       }
 
-               auto impl = std::make_unique<fasttext_model_impl>();
-
-               /* Read model args (52 bytes of packed data) */
-               auto &args = impl->args;
-               args.dim = reader.read_i32();
-               args.ws = reader.read_i32();
-               args.epoch = reader.read_i32();
-               args.minCount = reader.read_i32();
-               args.neg = reader.read_i32();
-               args.wordNgrams = reader.read_i32();
-               args.loss = reader.read_i32();
-               args.model = static_cast<model_name>(reader.read_i32());
-               args.bucket = reader.read_i32();
-               args.minn = reader.read_i32();
-               args.maxn = reader.read_i32();
-               args.lrUpdateRate = reader.read_i32();
-               args.t = reader.read_f64();
-
-               /* Read dictionary */
-               impl->dict.load(reader, args);
-
-               /* Determine if input matrix is quantized */
-               auto quant_input = reader.read_bool();
-
-               if (quant_input) {
-                       impl->input_matrix = load_quant_matrix(reader);
-               }
-               else {
-                       /* Dense input matrix - pointer into mmap region (zero-copy) */
-                       impl->input_matrix = load_dense_matrix(reader, base);
-               }
+       auto impl = std::make_unique<fasttext_model_impl>();
+
+       /* Read model args (52 bytes of packed data) */
+       auto &args = impl->args;
+       args.dim = reader.read_i32();
+       args.ws = reader.read_i32();
+       args.epoch = reader.read_i32();
+       args.minCount = reader.read_i32();
+       args.neg = reader.read_i32();
+       args.wordNgrams = reader.read_i32();
+       args.loss = reader.read_i32();
+       args.model = static_cast<model_name>(reader.read_i32());
+       args.bucket = reader.read_i32();
+       args.minn = reader.read_i32();
+       args.maxn = reader.read_i32();
+       args.lrUpdateRate = reader.read_i32();
+       args.t = reader.read_f64();
+
+       if (reader.fail()) {
+               return tl::make_unexpected(
+                       rspamd::util::error(
+                               fmt::format("truncated fasttext model header in '{}'", path),
+                               EINVAL));
+       }
 
-               /* Read output matrix - check if quantized */
-               if (!quant_input) {
-                       /* Dense output */
-                       impl->output_matrix = load_dense_matrix(reader, nullptr);
+       /* Read dictionary */
+       impl->dict.load(reader, args);
+
+       if (reader.fail()) {
+               return tl::make_unexpected(
+                       rspamd::util::error(
+                               fmt::format("truncated fasttext dictionary in '{}'", path),
+                               EINVAL));
+       }
+
+       /* Determine if input matrix is quantized */
+       auto quant_input = reader.read_bool();
+
+       if (quant_input) {
+               impl->input_matrix = load_quant_matrix(reader);
+       }
+       else {
+               /* Dense input matrix - pointer into mmap region (zero-copy) */
+               impl->input_matrix = load_dense_matrix(reader, base);
+       }
+
+       if (reader.fail()) {
+               return tl::make_unexpected(
+                       rspamd::util::error(
+                               fmt::format("truncated fasttext input matrix in '{}'", path),
+                               EINVAL));
+       }
+
+       /* Read output matrix - check if quantized */
+       if (!quant_input) {
+               /* Dense output */
+               impl->output_matrix = load_dense_matrix(reader, nullptr);
+       }
+       else {
+               auto quant_output = reader.read_bool();
+               if (quant_output) {
+                       impl->output_matrix = load_quant_matrix(reader);
                }
                else {
-                       auto quant_output = reader.read_bool();
-                       if (quant_output) {
-                               impl->output_matrix = load_quant_matrix(reader);
-                       }
-                       else {
-                               impl->output_matrix = load_dense_matrix(reader, nullptr);
-                       }
+                       impl->output_matrix = load_dense_matrix(reader, nullptr);
                }
+       }
 
-               /* Store the mmap to keep it alive */
-               impl->mmap_file.emplace(std::move(*mmap_result));
-
-               return fasttext_model(std::move(impl));
-       } catch (const std::exception &e) {
+       if (reader.fail()) {
                return tl::make_unexpected(
                        rspamd::util::error(
-                               fmt::format("failed to parse fasttext model '{}': {}", path, e.what()),
+                               fmt::format("truncated fasttext output matrix in '{}'", path),
                                EINVAL));
        }
+
+       /* Store the mmap to keep it alive */
+       impl->mmap_file.emplace(std::move(*mmap_result));
+
+       return fasttext_model(std::move(impl));
 }
 
 void fasttext_model::word2vec(std::string_view word, std::vector<std::int32_t> &ngrams) const