#include <string>
#include <string_view>
+#ifdef SYS_ZSTD
+#include "zstd.h"
+#else
+#include "contrib/zstd/zstd.h"
+#endif
+
TEST_SUITE("multipart_form")
{
TEST_CASE("basic two-part form")
}
}
+TEST_SUITE("multipart_zstd")
+{
+ TEST_CASE("serialize with compression produces Content-Encoding: zstd")
+ {
+ rspamd::http::multipart_response resp;
+ std::string data = "{\"action\":\"reject\",\"score\":15.0}";
+ resp.add_part("result", "application/json", data, true /* compress */);
+
+ ZSTD_CStream *cstream = ZSTD_createCStream();
+ ZSTD_initCStream(cstream, 1);
+
+ auto serialized = resp.serialize(cstream);
+ ZSTD_freeCStream(cstream);
+
+ /* Parse the multipart output */
+ auto boundary = std::string(resp.get_boundary());
+ auto parsed = rspamd::http::parse_multipart_form(serialized, boundary);
+ REQUIRE(parsed.has_value());
+ CHECK(parsed->parts.size() == 1);
+ CHECK(parsed->parts[0].content_encoding == "zstd");
+
+ /* The data should be compressed (not matching original) */
+ CHECK(parsed->parts[0].data != data);
+
+ /* Decompress and verify */
+ ZSTD_DStream *dstream = ZSTD_createDStream();
+ ZSTD_initDStream(dstream);
+ auto &compressed = parsed->parts[0].data;
+ ZSTD_inBuffer zin = {compressed.data(), compressed.size(), 0};
+ std::string decompressed(data.size() * 2, '\0');
+ ZSTD_outBuffer zout = {decompressed.data(), decompressed.size(), 0};
+
+ while (zin.pos < zin.size) {
+ size_t r = ZSTD_decompressStream(dstream, &zout, &zin);
+ REQUIRE(!ZSTD_isError(r));
+ }
+ ZSTD_freeDStream(dstream);
+ decompressed.resize(zout.pos);
+
+ CHECK(decompressed == data);
+ }
+
+ TEST_CASE("prepare_iov with compression round-trip")
+ {
+ rspamd::http::multipart_response resp;
+ std::string result = "{\"score\":42}";
+ std::string body = "Hello world body data for compression test";
+ resp.add_part("result", "application/json", result, true);
+ resp.add_part("body", "application/octet-stream", body, true);
+
+ ZSTD_CStream *cs = ZSTD_createCStream();
+ ZSTD_initCStream(cs, 1);
+ resp.prepare_iov(cs);
+ ZSTD_freeCStream(cs);
+
+ /* Reassemble iov */
+ std::string reassembled;
+ for (gsize i = 0; i < resp.body_iov_count(); i++) {
+ const auto *iov = &resp.body_iov()[i];
+ reassembled.append(static_cast<const char *>(iov->iov_base), iov->iov_len);
+ }
+ CHECK(reassembled.size() == resp.body_total_len());
+
+ /* Parse and verify both parts have zstd encoding */
+ auto boundary = std::string(resp.get_boundary());
+ auto parsed = rspamd::http::parse_multipart_form(reassembled, boundary);
+ REQUIRE(parsed.has_value());
+ CHECK(parsed->parts.size() == 2);
+ CHECK(parsed->parts[0].content_encoding == "zstd");
+ CHECK(parsed->parts[1].content_encoding == "zstd");
+
+ /* Decompress result part */
+ {
+ ZSTD_DStream *ds = ZSTD_createDStream();
+ ZSTD_initDStream(ds);
+ auto &comp = parsed->parts[0].data;
+ ZSTD_inBuffer zin = {comp.data(), comp.size(), 0};
+ std::string dec(result.size() * 4, '\0');
+ ZSTD_outBuffer zout = {dec.data(), dec.size(), 0};
+ while (zin.pos < zin.size) {
+ size_t r = ZSTD_decompressStream(ds, &zout, &zin);
+ REQUIRE(!ZSTD_isError(r));
+ }
+ ZSTD_freeDStream(ds);
+ dec.resize(zout.pos);
+ CHECK(dec == result);
+ }
+
+ /* Decompress body part */
+ {
+ ZSTD_DStream *ds = ZSTD_createDStream();
+ ZSTD_initDStream(ds);
+ auto &comp = parsed->parts[1].data;
+ ZSTD_inBuffer zin = {comp.data(), comp.size(), 0};
+ std::string dec(body.size() * 4, '\0');
+ ZSTD_outBuffer zout = {dec.data(), dec.size(), 0};
+ while (zin.pos < zin.size) {
+ size_t r = ZSTD_decompressStream(ds, &zout, &zin);
+ REQUIRE(!ZSTD_isError(r));
+ }
+ ZSTD_freeDStream(ds);
+ dec.resize(zout.pos);
+ CHECK(dec == body);
+ }
+ }
+
+ TEST_CASE("mixed compressed and uncompressed parts")
+ {
+ rspamd::http::multipart_response resp;
+ std::string result = "{\"action\":\"no action\"}";
+ std::string body = "Plain uncompressed body";
+ resp.add_part("result", "application/json", result, true); /* compressed */
+ resp.add_part("body", "application/octet-stream", body, false); /* uncompressed */
+
+ ZSTD_CStream *cs = ZSTD_createCStream();
+ ZSTD_initCStream(cs, 1);
+ resp.prepare_iov(cs);
+ ZSTD_freeCStream(cs);
+
+ std::string reassembled;
+ for (gsize i = 0; i < resp.body_iov_count(); i++) {
+ const auto *iov = &resp.body_iov()[i];
+ reassembled.append(static_cast<const char *>(iov->iov_base), iov->iov_len);
+ }
+
+ auto boundary = std::string(resp.get_boundary());
+ auto parsed = rspamd::http::parse_multipart_form(reassembled, boundary);
+ REQUIRE(parsed.has_value());
+ CHECK(parsed->parts.size() == 2);
+
+ /* Result part: compressed */
+ CHECK(parsed->parts[0].content_encoding == "zstd");
+
+ /* Body part: uncompressed — data should match directly */
+ CHECK(parsed->parts[1].content_encoding.empty());
+ CHECK(parsed->parts[1].data == body);
+ }
+
+ TEST_CASE("body_iov segments are writable for in-place encryption")
+ {
+ /* The encryption path (rspamd_cryptobox_encryptv_nm_inplace) writes
+ * to body_iov segments in-place. Verify all segments are writable. */
+ rspamd::http::multipart_response resp;
+ std::string result = "{\"action\":\"reject\"}";
+ std::string body = "Message body content here";
+ resp.add_part("result", "application/json", result, true);
+ resp.add_part("body", "application/octet-stream", body, false);
+
+ ZSTD_CStream *cs = ZSTD_createCStream();
+ ZSTD_initCStream(cs, 1);
+ resp.prepare_iov(cs);
+ ZSTD_freeCStream(cs);
+
+ /* Verify we can write to every byte of every iov segment
+ * (simulates what encryptv_nm_inplace does) */
+ for (gsize i = 0; i < resp.body_iov_count(); i++) {
+ auto *iov = &resp.body_iov()[i];
+ auto *p = static_cast<unsigned char *>(iov->iov_base);
+ for (gsize j = 0; j < iov->iov_len; j++) {
+ p[j] ^= 0xFF; /* XOR (simulate encryption) */
+ }
+ }
+
+ /* XOR back to restore */
+ for (gsize i = 0; i < resp.body_iov_count(); i++) {
+ auto *iov = &resp.body_iov()[i];
+ auto *p = static_cast<unsigned char *>(iov->iov_base);
+ for (gsize j = 0; j < iov->iov_len; j++) {
+ p[j] ^= 0xFF;
+ }
+ }
+
+ /* After restoring, reassemble and verify it parses correctly */
+ std::string reassembled;
+ for (gsize i = 0; i < resp.body_iov_count(); i++) {
+ const auto *iov = &resp.body_iov()[i];
+ reassembled.append(static_cast<const char *>(iov->iov_base), iov->iov_len);
+ }
+
+ auto boundary = std::string(resp.get_boundary());
+ auto parsed = rspamd::http::parse_multipart_form(reassembled, boundary);
+ REQUIRE(parsed.has_value());
+ CHECK(parsed->parts.size() == 2);
+ }
+}
+
#endif// RSPAMD_CXX_UNIT_MULTIPART_HXX