]> git.ipfire.org Git - thirdparty/zstd.git/commitdiff
use ZSTD_decodingBufferSize_min() inside ZSTD_decompressStream()
authorYann Collet <cyan@fb.com>
Sat, 9 Sep 2017 21:37:28 +0000 (14:37 -0700)
committerYann Collet <cyan@fb.com>
Sat, 9 Sep 2017 21:37:28 +0000 (14:37 -0700)
Use same definition as public one
minor : reduce allocated buffer size in some cases
(when frameContentSize is known and == windowSize)

lib/decompress/zstd_decompress.c
tests/zstreamtest.c

index aeac17d1928364a8dd746b093923c19dac481ef8..5158d3f32a8c5d29c043177888100e0663771cb7 100644 (file)
@@ -102,7 +102,8 @@ struct ZSTD_DCtx_s
     const void* dictEnd;          /* end of previous segment */
     size_t expected;
     ZSTD_frameHeader fParams;
-    blockType_e bType;   /* used in ZSTD_decompressContinue(), to transfer blockType between header decoding and block decoding stages */
+    U64 decodedSize;
+    blockType_e bType;            /* used in ZSTD_decompressContinue(), store blockType between block header decoding and block decompression stages */
     ZSTD_dStage stage;
     U32 litEntropy;
     U32 fseEntropy;
@@ -127,7 +128,6 @@ struct ZSTD_DCtx_s
     size_t outBuffSize;
     size_t outStart;
     size_t outEnd;
-    size_t blockSize;
     size_t lhSize;
     void* legacyContext;
     U32 previousLegacyVersion;
@@ -153,6 +153,7 @@ size_t ZSTD_decompressBegin(ZSTD_DCtx* dctx)
 {
     dctx->expected = ZSTD_frameHeaderSize_prefix;
     dctx->stage = ZSTDds_getFrameHeaderSize;
+    dctx->decodedSize = 0;
     dctx->previousDstEnd = NULL;
     dctx->base = NULL;
     dctx->vBase = NULL;
@@ -172,13 +173,13 @@ size_t ZSTD_decompressBegin(ZSTD_DCtx* dctx)
 static void ZSTD_initDCtx_internal(ZSTD_DCtx* dctx)
 {
     ZSTD_decompressBegin(dctx);   /* cannot fail */
-    dctx->staticSize = 0;
+    dctx->staticSize  = 0;
     dctx->maxWindowSize = ZSTD_MAXWINDOWSIZE_DEFAULT;
-    dctx->ddict   = NULL;
-    dctx->ddictLocal = NULL;
-    dctx->inBuff  = NULL;
-    dctx->inBuffSize = 0;
-    dctx->outBuffSize= 0;
+    dctx->ddict       = NULL;
+    dctx->ddictLocal  = NULL;
+    dctx->inBuff      = NULL;
+    dctx->inBuffSize  = 0;
+    dctx->outBuffSize = 0;
     dctx->streamStage = zdss_init;
 }
 
@@ -1771,9 +1772,16 @@ size_t ZSTD_decompressContinue(ZSTD_DCtx* dctx, void* dst, size_t dstCapacity, c
                 return ERROR(corruption_detected);
             }
             if (ZSTD_isError(rSize)) return rSize;
+            DEBUGLOG(5, "decoded size from block : %u", (U32)rSize);
+            dctx->decodedSize += rSize;
             if (dctx->fParams.checksumFlag) XXH64_update(&dctx->xxhState, dst, rSize);
 
             if (dctx->stage == ZSTDds_decompressLastBlock) {   /* end of frame */
+                DEBUGLOG(4, "decoded size from frame : %u", (U32)dctx->decodedSize);
+                if (dctx->fParams.frameContentSize != ZSTD_CONTENTSIZE_UNKNOWN) {
+                    if (dctx->decodedSize != dctx->fParams.frameContentSize) {
+                        return ERROR(corruption_detected);
+                }   }
                 if (dctx->fParams.checksumFlag) {  /* another round for frame checksum */
                     dctx->expected = 4;
                     dctx->stage = ZSTDds_checkChecksum;
@@ -1789,8 +1797,11 @@ size_t ZSTD_decompressContinue(ZSTD_DCtx* dctx, void* dst, size_t dstCapacity, c
             return rSize;
         }
     case ZSTDds_checkChecksum:
+        DEBUGLOG(4, "case ZSTDds_checkChecksum");
+        assert(srcSize == 4);  /* guaranteed by dctx->expected */
         {   U32 const h32 = (U32)XXH64_digest(&dctx->xxhState);
-            U32 const check32 = MEM_readLE32(src);   /* srcSize == 4, guaranteed by dctx->expected */
+            U32 const check32 = MEM_readLE32(src);
+            DEBUGLOG(4, "calculated %08X :: %08X read", h32, check32);
             if (check32 != h32) return ERROR(checksum_wrong);
             dctx->expected = 0;
             dctx->stage = ZSTDds_getFrameHeaderSize;
@@ -2361,15 +2372,14 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB
             if (zds->fParams.windowSize > zds->maxWindowSize) return ERROR(frameParameter_windowTooLarge);
 
             /* Adapt buffer sizes to frame header instructions */
-            {   size_t const blockSize = zds->fParams.blockSizeMax;
-                size_t const neededOutSize = (size_t)(zds->fParams.windowSize + blockSize + WILDCOPY_OVERLENGTH * 2);
-                zds->blockSize = blockSize;
-                if ((zds->inBuffSize < blockSize) || (zds->outBuffSize < neededOutSize)) {
-                    size_t const bufferSize = blockSize + neededOutSize;
+            {   size_t const neededInBuffSize = MAX(zds->fParams.blockSizeMax, 4 /* frame checksum */);
+                size_t const neededOutBuffSize = ZSTD_decodingBufferSize_min(zds->fParams.windowSize, zds->fParams.frameContentSize);
+                if ((zds->inBuffSize < neededInBuffSize) || (zds->outBuffSize < neededOutBuffSize)) {
+                    size_t const bufferSize = neededInBuffSize + neededOutBuffSize;
                     DEBUGLOG(4, "inBuff  : from %u to %u",
-                                (U32)zds->inBuffSize, (U32)blockSize);
+                                (U32)zds->inBuffSize, (U32)neededInBuffSize);
                     DEBUGLOG(4, "outBuff : from %u to %u",
-                                (U32)zds->outBuffSize, (U32)neededOutSize);
+                                (U32)zds->outBuffSize, (U32)neededOutBuffSize);
                     if (zds->staticSize) {  /* static DCtx */
                         DEBUGLOG(4, "staticSize : %u", (U32)zds->staticSize);
                         assert(zds->staticSize >= sizeof(ZSTD_DCtx));  /* controlled at init */
@@ -2382,9 +2392,9 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB
                         zds->inBuff = (char*)ZSTD_malloc(bufferSize, zds->customMem);
                         if (zds->inBuff == NULL) return ERROR(memory_allocation);
                     }
-                    zds->inBuffSize = blockSize;
+                    zds->inBuffSize = neededInBuffSize;
                     zds->outBuff = zds->inBuff + zds->inBuffSize;
-                    zds->outBuffSize = neededOutSize;
+                    zds->outBuffSize = neededOutBuffSize;
             }   }
             zds->streamStage = zdss_read;
             /* fall-through */
@@ -2442,8 +2452,13 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB
                 zds->outStart += flushedSize;
                 if (flushedSize == toFlushSize) {  /* flush completed */
                     zds->streamStage = zdss_read;
-                    if (zds->outStart + zds->blockSize > zds->outBuffSize)
+                    if ( (zds->outBuffSize < zds->fParams.frameContentSize)
+                      && (zds->outStart + zds->fParams.blockSizeMax > zds->outBuffSize) ) {
+                        DEBUGLOG(5, "restart filling outBuff from beginning (left:%i, needed:%u)",
+                                (int)(zds->outBuffSize - zds->outStart),
+                                (U32)zds->fParams.blockSizeMax);
                         zds->outStart = zds->outEnd = 0;
+                    }
                     break;
             }   }
             /* cannot complete flush */
index d7b2e197a9a510022edbc42710f68318b1f55f71..8c8adc62d4deb4ef65e223168e94a5f6724a3ddc 100644 (file)
@@ -909,10 +909,16 @@ static int fuzzerTests(U32 seed, U32 nbTests, unsigned startTest, double compres
                 inBuff.size = inBuff.pos + readCSrcSize;
                 outBuff.size = inBuff.pos + dstBuffSize;
                 decompressionResult = ZSTD_decompressStream(zd, &outBuff, &inBuff);
-                CHECK (ZSTD_isError(decompressionResult), "decompression error : %s", ZSTD_getErrorName(decompressionResult));
+                if (ZSTD_getErrorCode(decompressionResult) == ZSTD_error_checksum_wrong) {
+                    DISPLAY("checksum error : \n");
+                    findDiff(copyBuffer, dstBuffer, totalTestSize);
+                }
+                CHECK( ZSTD_isError(decompressionResult), "decompression error : %s",
+                       ZSTD_getErrorName(decompressionResult) );
             }
             CHECK (decompressionResult != 0, "frame not fully decoded");
-            CHECK (outBuff.pos != totalTestSize, "decompressed data : wrong size")
+            CHECK (outBuff.pos != totalTestSize, "decompressed data : wrong size (%u != %u)",
+                    (U32)outBuff.pos, (U32)totalTestSize);
             CHECK (inBuff.pos != cSize, "compressed data should be fully read")
             {   U64 const crcDest = XXH64(dstBuffer, totalTestSize, 0);
                 if (crcDest!=crcOrig) findDiff(copyBuffer, dstBuffer, totalTestSize);