From: shakeelrao Date: Thu, 28 Feb 2019 08:42:49 +0000 (-0800) Subject: Provide an API function to estimate decompressed size. X-Git-Tag: v1.4.0^2~30^2~13 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=820af1e07855019c95299e27666321fc8b096ebd;p=thirdparty%2Fzstd.git Provide an API function to estimate decompressed size. Introduces a new utility function `ZSTD_findFrameCompressedSize_internal` which is equivalent to `ZSTD_findFrameCompressedSize`, but accepts an additional output parameter `bound` that receives an upper bound on the decompressed size of the frame. The new API function is named `ZSTD_decompressBound` to be consistent with `ZSTD_compressBound` (the inverse operation). Clients will now be able to compute an upper bound on the decompressed size of their compressed payloads instead of guessing a large size. Implements https://github.com/facebook/zstd/issues/1536. --- diff --git a/doc/zstd_manual.html b/doc/zstd_manual.html index c7962e7de..a11a95892 100644 --- a/doc/zstd_manual.html +++ b/doc/zstd_manual.html @@ -127,10 +127,11 @@ unsigned long long ZSTD_getFrameContentSize(const void *src, size_t srcSize);


Helper functions

#define ZSTD_COMPRESSBOUND(srcSize)   ((srcSize) + ((srcSize)>>8) + (((srcSize) < (128<<10)) ? (((128<<10) - (srcSize)) >> 11) /* margin, from 64 to 0 */ : 0))  /* this formula ensures that bound(A) + bound(B) <= bound(A+B) as long as A and B >= 128 KB */
-size_t      ZSTD_compressBound(size_t srcSize); /*!< maximum compressed size in worst case single-pass scenario */
-unsigned    ZSTD_isError(size_t code);          /*!< tells if a `size_t` function result is an error code */
-const char* ZSTD_getErrorName(size_t code);     /*!< provides readable string from an error code */
-int         ZSTD_maxCLevel(void);               /*!< maximum compression level available */
+size_t      ZSTD_compressBound(size_t srcSize);                    /*!< maximum compressed size in worst case single-pass scenario */
+size_t      ZSTD_decompressBound(const void* src, size_t srcSize); /*!< maximum decompressed size of the compressed source */
+unsigned    ZSTD_isError(size_t code);                             /*!< tells if a `size_t` function result is an error code */
+const char* ZSTD_getErrorName(size_t code);                        /*!< provides readable string from an error code */
+int         ZSTD_maxCLevel(void);                                  /*!< maximum compression level available */
 

Explicit context


 
diff --git a/lib/decompress/zstd_decompress.c b/lib/decompress/zstd_decompress.c
index 601bfe704..a29122a83 100644
--- a/lib/decompress/zstd_decompress.c
+++ b/lib/decompress/zstd_decompress.c
@@ -434,12 +434,7 @@ static size_t ZSTD_decodeFrameHeader(ZSTD_DCtx* dctx, const void* src, size_t he
 }
 
 
-/** ZSTD_findFrameCompressedSize() :
- *  compatible with legacy mode
- *  `src` must point to the start of a ZSTD frame, ZSTD legacy frame, or skippable frame
- *  `srcSize` must be at least as large as the frame contained
- *  @return : the compressed size of the frame starting at `src` */
-size_t ZSTD_findFrameCompressedSize(const void *src, size_t srcSize)
+static size_t ZSTD_findFrameCompressedSize_internal(const void *src, size_t srcSize, size_t* bound)
 {
 #if defined(ZSTD_LEGACY_SUPPORT) && (ZSTD_LEGACY_SUPPORT >= 1)
     if (ZSTD_isLegacy(src, srcSize))
@@ -464,6 +459,7 @@ size_t ZSTD_findFrameCompressedSize(const void *src, size_t srcSize)
         remainingSize -= zfh.headerSize;
 
         /* Loop on each block */
+        unsigned nbBlocks = 0;
         while (1) {
             blockProperties_t blockProperties;
             size_t const cBlockSize = ZSTD_getcBlockSize(ip, remainingSize, &blockProperties);
@@ -474,6 +470,7 @@ size_t ZSTD_findFrameCompressedSize(const void *src, size_t srcSize)
 
             ip += ZSTD_blockHeaderSize + cBlockSize;
             remainingSize -= ZSTD_blockHeaderSize + cBlockSize;
+            nbBlocks++;
 
             if (blockProperties.lastBlock) break;
         }
@@ -483,10 +480,78 @@ size_t ZSTD_findFrameCompressedSize(const void *src, size_t srcSize)
             ip += 4;
         }
 
+        if (bound != NULL) *bound = (nbBlocks * zfh.blockSizeMax); /* set to block-based bound */
+
         return ip - ipstart;
     }
 }
 
+/** ZSTD_findFrameCompressedSize() :
+ *  works on ZSTD frames, ZSTD legacy frames and skippable frames alike;
+ *  `src` is the first byte of the frame, `srcSize` covers at least the whole frame
+ *  @return : compressed size of the frame starting at `src` (no bound is requested) */
+size_t ZSTD_findFrameCompressedSize(const void *src, size_t srcSize)
+{
+    size_t* const noBound = NULL;   /* plain size query : skip the decompressed-size bound */
+    return ZSTD_findFrameCompressedSize_internal(src, srcSize, noBound);
+}
+
+
+/** ZSTD_decompressBound() :
+ *  upper bound on the total decompressed size of every frame in `src`; legacy frames are rejected
+ *  `src` must point to the start of a ZSTD frame or a skippable frame
+ *  `srcSize` must be the exact size of one or more complete frames (leftover bytes are an error)
+ *  a frame adds its declared content size or, when the header has none, the block-based
+ *  bound reported by ZSTD_findFrameCompressedSize_internal(); skippable frames add nothing
+ *  @return : the bound, or an error value : ERROR(version_unsupported), the error code of
+ *            readSkippableFrameSize(), or ZSTD_CONTENTSIZE_ERROR converted to size_t
+ */
+size_t ZSTD_decompressBound(const void* src, size_t srcSize)
+{
+    size_t totalDstSize = 0;   /* declared before any statement (C90) */
+
+    /* Loop over each frame */
+    while (srcSize >= ZSTD_FRAMEHEADERSIZE_PREFIX) {
+        U32 const magicNumber = MEM_readLE32(src);
+
+#if defined(ZSTD_LEGACY_SUPPORT) && (ZSTD_LEGACY_SUPPORT >= 1)
+        /* tested on every frame, not only the first one : legacy frames are unsupported anywhere in `src` */
+        if (ZSTD_isLegacy(src, srcSize))
+            return ERROR(version_unsupported);
+#endif
+
+        if ((magicNumber & ZSTD_MAGIC_SKIPPABLE_MASK) == ZSTD_MAGIC_SKIPPABLE_START) {
+            size_t const skippableSize = readSkippableFrameSize(src, srcSize);
+            if (ZSTD_isError(skippableSize))
+                return skippableSize;
+            if (srcSize < skippableSize)
+                return (size_t)ZSTD_CONTENTSIZE_ERROR;
+
+            src = (const BYTE *)src + skippableSize;
+            srcSize -= skippableSize;
+            continue;
+        }
+
+        {   unsigned long long const contentSize = ZSTD_getFrameContentSize(src, srcSize);
+            size_t blockBound = 0;   /* stays 0 if the callee does not write it */
+            size_t const frameSrcSize = ZSTD_findFrameCompressedSize_internal(src, srcSize, &blockBound);
+            unsigned long long const frameBound = (contentSize == ZSTD_CONTENTSIZE_UNKNOWN) ? blockBound : contentSize;
+            if (contentSize == ZSTD_CONTENTSIZE_ERROR || ZSTD_isError(frameSrcSize))
+                return (size_t)ZSTD_CONTENTSIZE_ERROR;
+            /* compared in 64 bits before narrowing : also rejects a content size larger than size_t */
+            if (frameBound > (size_t)-1 - totalDstSize)
+                return (size_t)ZSTD_CONTENTSIZE_ERROR;
+            totalDstSize += (size_t)frameBound;
+
+            src = (const BYTE *)src + frameSrcSize;
+            srcSize -= frameSrcSize;
+        }
+    }  /* while (srcSize >= ZSTD_FRAMEHEADERSIZE_PREFIX) */
+
+    if (srcSize) return (size_t)ZSTD_CONTENTSIZE_ERROR;   /* trailing bytes too short to be a frame */
+
+    return totalDstSize;
+}
 
 
 /*-*************************************************************
diff --git a/lib/zstd.h b/lib/zstd.h
index 98020383f..313411017 100644
--- a/lib/zstd.h
+++ b/lib/zstd.h
@@ -148,10 +148,11 @@ ZSTDLIB_API unsigned long long ZSTD_getDecompressedSize(const void* src, size_t
 
 /*======  Helper functions  ======*/
 #define ZSTD_COMPRESSBOUND(srcSize)   ((srcSize) + ((srcSize)>>8) + (((srcSize) < (128<<10)) ? (((128<<10) - (srcSize)) >> 11) /* margin, from 64 to 0 */ : 0))  /* this formula ensures that bound(A) + bound(B) <= bound(A+B) as long as A and B >= 128 KB */
-ZSTDLIB_API size_t      ZSTD_compressBound(size_t srcSize); /*!< maximum compressed size in worst case single-pass scenario */
-ZSTDLIB_API unsigned    ZSTD_isError(size_t code);          /*!< tells if a `size_t` function result is an error code */
-ZSTDLIB_API const char* ZSTD_getErrorName(size_t code);     /*!< provides readable string from an error code */
-ZSTDLIB_API int         ZSTD_maxCLevel(void);               /*!< maximum compression level available */
+ZSTDLIB_API size_t      ZSTD_compressBound(size_t srcSize);                    /*!< maximum compressed size in worst case single-pass scenario */
+ZSTDLIB_API size_t      ZSTD_decompressBound(const void* src, size_t srcSize); /*!< maximum decompressed size of the compressed source */
+ZSTDLIB_API unsigned    ZSTD_isError(size_t code);                             /*!< tells if a `size_t` function result is an error code */
+ZSTDLIB_API const char* ZSTD_getErrorName(size_t code);                        /*!< provides readable string from an error code */
+ZSTDLIB_API int         ZSTD_maxCLevel(void);                                  /*!< maximum compression level available */
 
 
 /***************************************
diff --git a/tests/fuzzer.c b/tests/fuzzer.c
index 9aed11e38..00e7f5442 100644
--- a/tests/fuzzer.c
+++ b/tests/fuzzer.c
@@ -376,6 +376,20 @@ static int basicUnitTests(U32 seed, double compressibility)
     }
     DISPLAYLEVEL(3, "OK \n");
 
+    DISPLAYLEVEL(3, "test%3i : tight ZSTD_decompressBound test : ", testNb++);
+    {
+        size_t bound = ZSTD_decompressBound(compressedBuffer, cSize);
+        if (bound != CNBuffSize) goto _output_error;
+    }
+    DISPLAYLEVEL(3, "OK \n");
+
+    DISPLAYLEVEL(3, "test%3i : ZSTD_decompressBound test missing Frame_Content_Size : ", testNb++);
+    {
+        size_t bound = ZSTD_decompressBound(compressedBuffer, cSize);
+        if (bound != CNBuffSize) goto _output_error;
+    }
+    DISPLAYLEVEL(3, "OK \n");
+
     DISPLAYLEVEL(3, "test%3i : decompress %u bytes : ", testNb++, (unsigned)CNBuffSize);
     { size_t const r = ZSTD_decompress(decodedBuffer, CNBuffSize, compressedBuffer, cSize);
       if (r != CNBuffSize) goto _output_error; }
@@ -901,6 +915,11 @@ static int basicUnitTests(U32 seed, double compressibility)
         if (r != CNBuffSize / 2) goto _output_error; }
     DISPLAYLEVEL(3, "OK \n");
 
+    DISPLAYLEVEL(3, "test%3i : get tight decompressed bound of multiple frames : ", testNb++);
+    {   unsigned long long const r = ZSTD_decompressBound(compressedBuffer, cSize);
+        if (r != CNBuffSize / 2) goto _output_error; }
+    DISPLAYLEVEL(3, "OK \n");
+
     DISPLAYLEVEL(3, "test%3i : decompress multiple frames : ", testNb++);
     {   CHECK_V(r, ZSTD_decompress(decodedBuffer, CNBuffSize, compressedBuffer, cSize));
         if (r != CNBuffSize / 2) goto _output_error; }
diff --git a/tests/symbols.c b/tests/symbols.c
index 600d81670..4d9c6fc0c 100644
--- a/tests/symbols.c
+++ b/tests/symbols.c
@@ -31,6 +31,7 @@ static const void *symbols[] = {
   &ZSTD_getFrameContentSize,
   &ZSTD_maxCLevel,
   &ZSTD_compressBound,
+  &ZSTD_decompressBound,
   &ZSTD_isError,
   &ZSTD_getErrorName,
   &ZSTD_createCCtx,