]> git.ipfire.org Git - thirdparty/zstd.git/commitdiff
Implement one-shot fallback for magicless format (#3971)
authorElliot Gorokhovsky <embg@fb.com>
Mon, 18 Mar 2024 14:55:53 +0000 (10:55 -0400)
committerGitHub <noreply@github.com>
Mon, 18 Mar 2024 14:55:53 +0000 (10:55 -0400)
lib/decompress/zstd_decompress.c
tests/zstreamtest.c

index f6579743859daf123da3a9b61b312ddc8fc1ffb6..ee2cda3b6390a636a7814734fecc637e2700fa64 100644 (file)
@@ -729,17 +729,17 @@ static ZSTD_frameSizeInfo ZSTD_errorFrameSizeInfo(size_t ret)
     return frameSizeInfo;
 }
 
-static ZSTD_frameSizeInfo ZSTD_findFrameSizeInfo(const void* src, size_t srcSize)
+static ZSTD_frameSizeInfo ZSTD_findFrameSizeInfo(const void* src, size_t srcSize, ZSTD_format_e format)
 {
     ZSTD_frameSizeInfo frameSizeInfo;
     ZSTD_memset(&frameSizeInfo, 0, sizeof(ZSTD_frameSizeInfo));
 
 #if defined(ZSTD_LEGACY_SUPPORT) && (ZSTD_LEGACY_SUPPORT >= 1)
-    if (ZSTD_isLegacy(src, srcSize))
+    if (format == ZSTD_f_zstd1 && ZSTD_isLegacy(src, srcSize))
         return ZSTD_findFrameSizeInfoLegacy(src, srcSize);
 #endif
 
-    if ((srcSize >= ZSTD_SKIPPABLEHEADERSIZE)
+    if (format == ZSTD_f_zstd1 && (srcSize >= ZSTD_SKIPPABLEHEADERSIZE)
         && (MEM_readLE32(src) & ZSTD_MAGIC_SKIPPABLE_MASK) == ZSTD_MAGIC_SKIPPABLE_START) {
         frameSizeInfo.compressedSize = readSkippableFrameSize(src, srcSize);
         assert(ZSTD_isError(frameSizeInfo.compressedSize) ||
@@ -753,7 +753,7 @@ static ZSTD_frameSizeInfo ZSTD_findFrameSizeInfo(const void* src, size_t srcSize
         ZSTD_frameHeader zfh;
 
         /* Extract Frame Header */
-        {   size_t const ret = ZSTD_getFrameHeader(&zfh, src, srcSize);
+        {   size_t const ret = ZSTD_getFrameHeader_advanced(&zfh, src, srcSize, format);
             if (ZSTD_isError(ret))
                 return ZSTD_errorFrameSizeInfo(ret);
             if (ret > 0)
@@ -796,13 +796,17 @@ static ZSTD_frameSizeInfo ZSTD_findFrameSizeInfo(const void* src, size_t srcSize
     }
 }
 
+static size_t ZSTD_findFrameCompressedSize_advanced(const void *src, size_t srcSize, ZSTD_format_e format) {
+    ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize, format);
+    return frameSizeInfo.compressedSize;
+}
+
 /** ZSTD_findFrameCompressedSize() :
  * See docs in zstd.h
  * Note: compatible with legacy mode */
 size_t ZSTD_findFrameCompressedSize(const void *src, size_t srcSize)
 {
-    ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize);
-    return frameSizeInfo.compressedSize;
+    return ZSTD_findFrameCompressedSize_advanced(src, srcSize, ZSTD_f_zstd1);
 }
 
 /** ZSTD_decompressBound() :
@@ -816,7 +820,7 @@ unsigned long long ZSTD_decompressBound(const void* src, size_t srcSize)
     unsigned long long bound = 0;
     /* Iterate over each frame */
     while (srcSize > 0) {
-        ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize);
+        ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize, ZSTD_f_zstd1);
         size_t const compressedSize = frameSizeInfo.compressedSize;
         unsigned long long const decompressedBound = frameSizeInfo.decompressedBound;
         if (ZSTD_isError(compressedSize) || decompressedBound == ZSTD_CONTENTSIZE_ERROR)
@@ -836,7 +840,7 @@ size_t ZSTD_decompressionMargin(void const* src, size_t srcSize)
 
     /* Iterate over each frame */
     while (srcSize > 0) {
-        ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize);
+        ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize, ZSTD_f_zstd1);
         size_t const compressedSize = frameSizeInfo.compressedSize;
         unsigned long long const decompressedBound = frameSizeInfo.decompressedBound;
         ZSTD_frameHeader zfh;
@@ -2178,7 +2182,7 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB
             if (zds->fParams.frameContentSize != ZSTD_CONTENTSIZE_UNKNOWN
                 && zds->fParams.frameType != ZSTD_skippableFrame
                 && (U64)(size_t)(oend-op) >= zds->fParams.frameContentSize) {
-                size_t const cSize = ZSTD_findFrameCompressedSize(istart, (size_t)(iend-istart));
+                size_t const cSize = ZSTD_findFrameCompressedSize_advanced(istart, (size_t)(iend-istart), zds->format);
                 if (cSize <= (size_t)(iend-istart)) {
                     /* shortcut : using single-pass mode */
                     size_t const decompressedSize = ZSTD_decompress_usingDDict(zds, op, (size_t)(oend-op), istart, cSize, ZSTD_getDDict(zds));
index 7cc4068bc094661d1ff8cb3e87b19b7dbfbae980..e0ee4c3e934f564d3ea047585ff7b7ffb08b5049 100644 (file)
@@ -2417,6 +2417,41 @@ static int basicUnitTests(U32 seed, double compressibility, int bigTests)
     }
     DISPLAYLEVEL(3, "OK \n");
 
+    DISPLAYLEVEL(3, "test%3i : Test single-shot fallback for magicless mode: ", testNb++);
+    {
+        // Aquire resources
+        size_t const srcSize = COMPRESSIBLE_NOISE_LENGTH;
+        void* src = malloc(srcSize);
+        size_t const dstSize = ZSTD_compressBound(srcSize);
+        void* dst = malloc(dstSize);
+        size_t const valSize = srcSize;
+        void* val = malloc(valSize);
+        ZSTD_inBuffer inBuf = { dst, dstSize, 0 };
+        ZSTD_outBuffer outBuf = { val, valSize, 0 };
+        ZSTD_CCtx* cctx = ZSTD_createCCtx();
+        ZSTD_DCtx* dctx = ZSTD_createDCtx();
+        CHECK(!src || !dst || !val || !dctx || !cctx, "memory allocation failure");
+
+        // Write test data for decompression to dst
+        RDG_genBuffer(src, srcSize, compressibility, 0.0, 0xdeadbeef);
+        CHECK_Z(ZSTD_CCtx_setParameter(cctx, ZSTD_c_format, ZSTD_f_zstd1_magicless));
+        CHECK_Z(ZSTD_compress2(cctx, dst, dstSize, src, srcSize));
+
+        // Run decompression
+        CHECK_Z(ZSTD_DCtx_setParameter(dctx, ZSTD_d_format, ZSTD_f_zstd1_magicless));
+        CHECK_Z(ZSTD_decompressStream(dctx, &outBuf, &inBuf));
+
+        // Validate
+        CHECK(outBuf.pos != srcSize, "decompressed size must match");
+        CHECK(memcmp(src, val, srcSize) != 0, "decompressed data must match");
+        
+        // Cleanup
+        free(src); free(dst); free(val);
+        ZSTD_freeCCtx(cctx);
+        ZSTD_freeDCtx(dctx);
+    }
+    DISPLAYLEVEL(3, "OK \n");
+
 _end:
     FUZ_freeDictionary(dictionary);
     ZSTD_freeCStream(zc);