]> git.ipfire.org Git - thirdparty/zstd.git/commitdiff
ZSTD_decompress now handles multiple frames
authorSean Purcell <me@seanp.xyz>
Wed, 8 Feb 2017 00:16:55 +0000 (16:16 -0800)
committerSean Purcell <me@seanp.xyz>
Wed, 8 Feb 2017 22:50:10 +0000 (14:50 -0800)
lib/decompress/zstd_decompress.c
lib/legacy/zstd_v07.c
tests/fuzzer.c

index 00c0d937e5c94879819596cdc32250895cd82814..18ee070cfc4974da52a596f7efae92ea384030cd 100644 (file)
@@ -354,8 +354,7 @@ unsigned long long ZSTD_findDecompressedSize(const void* src, size_t srcSize)
                 skippableSize = MEM_readLE32((const BYTE *)src + 4) +
                                 ZSTD_skippableHeaderSize;
                 if (srcSize < skippableSize) {
-                    /* srcSize_wrong */
-                    return 0;
+                    return ZSTD_CONTENTSIZE_ERROR;
                 }
 
                 src = (const BYTE *)src + skippableSize;
@@ -384,7 +383,7 @@ unsigned long long ZSTD_findDecompressedSize(const void* src, size_t srcSize)
                     frameSrcSize = ZSTD_frameSrcSize(src, srcSize);
                 }
                 if (ZSTD_isError(frameSrcSize)) {
-                    return 0;
+                    return ZSTD_CONTENTSIZE_ERROR;
                 }
 
                 src = (const BYTE *)src + frameSrcSize;
@@ -393,8 +392,7 @@ unsigned long long ZSTD_findDecompressedSize(const void* src, size_t srcSize)
         }
 
         if (srcSize) {
-            /* srcSize_wrong */
-            return 0;
+            return ZSTD_CONTENTSIZE_ERROR;
         }
 
         return totalDstSize;
@@ -1498,22 +1496,22 @@ static size_t ZSTD_frameSrcSize(const void *src, size_t srcSize)
 *   `dctx` must be properly initialized */
 static size_t ZSTD_decompressFrame(ZSTD_DCtx* dctx,
                                  void* dst, size_t dstCapacity,
-                                 const void* src, size_t srcSize)
+                                 const void** srcPtr, size_t *srcSizePtr)
 {
-    const BYTE* ip = (const BYTE*)src;
+    const BYTE* ip = (const BYTE*)(*srcPtr);
     BYTE* const ostart = (BYTE* const)dst;
     BYTE* const oend = ostart + dstCapacity;
     BYTE* op = ostart;
-    size_t remainingSize = srcSize;
+    size_t remainingSize = *srcSizePtr;
 
     /* check */
-    if (srcSize < ZSTD_frameHeaderSize_min+ZSTD_blockHeaderSize) return ERROR(srcSize_wrong);
+    if (remainingSize < ZSTD_frameHeaderSize_min+ZSTD_blockHeaderSize) return ERROR(srcSize_wrong);
 
     /* Frame Header */
-    {   size_t const frameHeaderSize = ZSTD_frameHeaderSize(src, ZSTD_frameHeaderSize_prefix);
+    {   size_t const frameHeaderSize = ZSTD_frameHeaderSize(ip, ZSTD_frameHeaderSize_prefix);
         if (ZSTD_isError(frameHeaderSize)) return frameHeaderSize;
-        if (srcSize < frameHeaderSize+ZSTD_blockHeaderSize) return ERROR(srcSize_wrong);
-        CHECK_F(ZSTD_decodeFrameHeader(dctx, src, frameHeaderSize));
+        if (remainingSize < frameHeaderSize+ZSTD_blockHeaderSize) return ERROR(srcSize_wrong);
+        CHECK_F(ZSTD_decodeFrameHeader(dctx, ip, frameHeaderSize));
         ip += frameHeaderSize; remainingSize -= frameHeaderSize;
     }
 
@@ -1558,25 +1556,98 @@ static size_t ZSTD_decompressFrame(ZSTD_DCtx* dctx,
         if (remainingSize<4) return ERROR(checksum_wrong);
         checkRead = MEM_readLE32(ip);
         if (checkRead != checkCalc) return ERROR(checksum_wrong);
+        ip += 4;
         remainingSize -= 4;
     }
 
-    if (remainingSize) return ERROR(srcSize_wrong);
+    // Allow caller to get size read
+    *srcPtr = ip;
+    *srcSizePtr = remainingSize;
     return op-ostart;
 }
 
+static size_t ZSTD_decompressMultiFrame(ZSTD_DCtx* dctx,
+                                        void* dst, size_t dstCapacity,
+                                  const void* src, size_t srcSize,
+                                  const void *dict, size_t dictSize,
+                                  const ZSTD_DCtx* refContext)
+{
+    void* const dststart = dst;
+    while (srcSize >= ZSTD_frameHeaderSize_prefix) {
+        U32 magicNumber;
+
+#if defined(ZSTD_LEGACY_SUPPORT) && (ZSTD_LEGACY_SUPPORT >= 1)
+        if (ZSTD_isLegacy(src, srcSize)) {
+            size_t const frameSize = ZSTD_frameSrcSizeLegacy(src, srcSize);
+            size_t decodedSize;
+            if (ZSTD_isError(frameSize)) return frameSize;
+
+            decodedSize = ZSTD_decompressLegacy(dst, dstCapacity, src, frameSize, dict, dictSize);
+
+            dst = (BYTE*)dst + decodedSize;
+            dstCapacity -= decodedSize;
+
+            src = (const BYTE*)src + frameSize;
+            srcSize -= frameSize;
+
+            continue;
+        }
+#endif
+
+        magicNumber = MEM_readLE32(src);
+        if (magicNumber != ZSTD_MAGICNUMBER) {
+            if ((magicNumber & 0xFFFFFFF0U) == ZSTD_MAGIC_SKIPPABLE_START) {
+                size_t skippableSize;
+                if (srcSize < ZSTD_skippableHeaderSize)
+                    return ERROR(srcSize_wrong);
+                skippableSize = MEM_readLE32((const BYTE *)src + 4) +
+                                ZSTD_skippableHeaderSize;
+                if (srcSize < skippableSize) {
+                    return ERROR(srcSize_wrong);
+                }
+
+                src = (const BYTE *)src + skippableSize;
+                srcSize -= skippableSize;
+                continue;
+            } else {
+                return ERROR(prefix_unknown);
+            }
+        }
+
+        if (refContext) {
+            /* we were called from ZSTD_decompress_usingDDict */
+            ZSTD_refDCtx(dctx, refContext);
+        } else {
+            /* this will initialize correctly with no dict if dict == NULL, so
+             * use this in all cases but ddict */
+            CHECK_F(ZSTD_decompressBegin_usingDict(dctx, dict, dictSize));
+        }
+        ZSTD_checkContinuity(dctx, dst);
+
+        {
+            const size_t res = ZSTD_decompressFrame(dctx, dst, dstCapacity,
+                                                    &src, &srcSize);
+            if (ZSTD_isError(res)) return res;
+            /* don't need to bounds check this, ZSTD_decompressFrame will have
+             * already */
+            dst = (BYTE*)dst + res;
+            dstCapacity -= res;
+        }
+    }
+
+    if (srcSize) {
+        return ERROR(srcSize_wrong);
+    }
+
+    return (BYTE*)dst - (BYTE*)dststart;
+}
 
 size_t ZSTD_decompress_usingDict(ZSTD_DCtx* dctx,
                                  void* dst, size_t dstCapacity,
                            const void* src, size_t srcSize,
                            const void* dict, size_t dictSize)
 {
-#if defined(ZSTD_LEGACY_SUPPORT) && (ZSTD_LEGACY_SUPPORT==1)
-    if (ZSTD_isLegacy(src, srcSize)) return ZSTD_decompressLegacy(dst, dstCapacity, src, srcSize, dict, dictSize);
-#endif
-    CHECK_F(ZSTD_decompressBegin_usingDict(dctx, dict, dictSize));
-    ZSTD_checkContinuity(dctx, dst);
-    return ZSTD_decompressFrame(dctx, dst, dstCapacity, src, srcSize);
+    return ZSTD_decompressMultiFrame(dctx, dst, dstCapacity, src, srcSize, dict, dictSize, NULL);
 }
 
 
@@ -1973,12 +2044,10 @@ size_t ZSTD_decompress_usingDDict(ZSTD_DCtx* dctx,
                             const void* src, size_t srcSize,
                             const ZSTD_DDict* ddict)
 {
-#if defined(ZSTD_LEGACY_SUPPORT) && (ZSTD_LEGACY_SUPPORT==1)
-    if (ZSTD_isLegacy(src, srcSize)) return ZSTD_decompressLegacy(dst, dstCapacity, src, srcSize, ddict->dictContent, ddict->dictSize);
-#endif
-    ZSTD_refDCtx(dctx, ddict->refContext);
-    ZSTD_checkContinuity(dctx, dst);
-    return ZSTD_decompressFrame(dctx, dst, dstCapacity, src, srcSize);
+    /* pass content and size in case legacy frames are encountered */
+    return ZSTD_decompressMultiFrame(dctx, dst, dstCapacity, src, srcSize,
+                                     ddict->dictContent, ddict->dictSize,
+                                     ddict->refContext);
 }
 
 
index 4ee055ddda5ceee1a967939cea97a69f3bcc446a..6228ee8707e64ee6c5f6cf89d8aed85945b2f3c8 100644 (file)
@@ -3992,6 +3992,9 @@ size_t ZSTDv07_frameSrcSize(const void* src, size_t srcSize)
 
         ip += ZSTDv07_blockHeaderSize;
         remainingSize -= ZSTDv07_blockHeaderSize;
+
+        if (blockProperties.blockType == bt_end) break;
+
         if (cBlockSize > remainingSize) return ERROR(srcSize_wrong);
 
         ip += cBlockSize;
index e7f25aae5691c62856efaa4b0f83d0eac2b5f117..4239af4558b88e853a20aa81b96ce06a242d0c50 100644 (file)
@@ -167,6 +167,46 @@ static int basicUnitTests(U32 seed, double compressibility)
       if (ZSTD_getErrorCode(r) != ZSTD_error_srcSize_wrong) goto _output_error; }
     DISPLAYLEVEL(4, "OK \n");
 
+    /* Simple API multiframe test */
+    DISPLAYLEVEL(4, "test%3i : compress multiple frames : ", testNb++);
+    {   size_t off = 0;
+        int i;
+        int const segs = 4;
+        /* only use the first half so we don't push against size limit of compressedBuffer */
+        size_t const segSize = (CNBuffSize / 2) / segs;
+        for (i = 0; i < segs; i++) {
+            CHECK_V(r,
+                    ZSTD_compress(
+                            (BYTE *)compressedBuffer + off, CNBuffSize - off,
+                            (BYTE *)CNBuffer + segSize * i,
+                            segSize, 5));
+            off += r;
+            if (i == segs/2) {
+                /* insert skippable frame */
+                const U32 skipLen = 128 KB;
+                MEM_writeLE32((BYTE*)compressedBuffer + off, ZSTD_MAGIC_SKIPPABLE_START);
+                MEM_writeLE32((BYTE*)compressedBuffer + off + 4, skipLen);
+                off += skipLen + ZSTD_skippableHeaderSize;
+            }
+        }
+        cSize = off;
+    }
+    DISPLAYLEVEL(4, "OK \n");
+
+    DISPLAYLEVEL(4, "test%3i : get decompressed size of multiple frames : ", testNb++);
+    {   unsigned long long const r = ZSTD_findDecompressedSize(compressedBuffer, cSize);
+        if (r != CNBuffSize / 2) goto _output_error; }
+    DISPLAYLEVEL(4, "OK \n");
+
+    DISPLAYLEVEL(4, "test%3i : decompress multiple frames : ", testNb++);
+    {   CHECK_V(r, ZSTD_decompress(decodedBuffer, CNBuffSize, compressedBuffer, cSize));
+        if (r != CNBuffSize / 2) goto _output_error; }
+    DISPLAYLEVEL(4, "OK \n");
+
+    DISPLAYLEVEL(4, "test%3i : check decompressed result : ", testNb++);
+    if (memcmp(decodedBuffer, CNBuffer, CNBuffSize / 2) != 0) goto _output_error;
+    DISPLAYLEVEL(4, "OK \n");
+
     /* Dictionary and CCtx Duplication tests */
     {   ZSTD_CCtx* const ctxOrig = ZSTD_createCCtx();
         ZSTD_CCtx* const ctxDuplicated = ZSTD_createCCtx();