git.ipfire.org Git - thirdparty/zstd.git/commitdiff
fix extended case combining stableInBuffer with continue() and flush() modes
author Yann Collet <cyan@fb.com>
Tue, 25 Jan 2022 06:57:55 +0000 (22:57 -0800)
committer Yann Collet <cyan@fb.com>
Wed, 26 Jan 2022 18:31:25 +0000 (10:31 -0800)
lib/compress/zstd_compress.c
lib/compress/zstd_compress_internal.h
tests/fuzzer.c
tests/zstreamtest.c

index 7be5d87a7515f2a3568971fc1a786e1c3ca1b79c..e2a6d9675c549dc82898dac08018d01023efa6b7 100644 (file)
@@ -5316,9 +5316,14 @@ size_t ZSTD_initCStream(ZSTD_CStream* zcs, int compressionLevel)
 
 static size_t ZSTD_nextInputSizeHint(const ZSTD_CCtx* cctx)
 {
-    size_t hintInSize = cctx->inBuffTarget - cctx->inBuffPos;
-    if (hintInSize==0) hintInSize = cctx->blockSize;
-    return hintInSize;
+    if (cctx->appliedParams.inBufferMode == ZSTD_bm_stable) {
+        return cctx->blockSize - cctx->stableIn_notConsumed;
+    }
+    assert(cctx->appliedParams.inBufferMode == ZSTD_bm_buffered);
+    {   size_t hintInSize = cctx->inBuffTarget - cctx->inBuffPos;
+        if (hintInSize==0) hintInSize = cctx->blockSize;
+        return hintInSize;
+    }
 }
 
 /** ZSTD_compressStream_generic():
@@ -5329,16 +5334,23 @@ static size_t ZSTD_compressStream_generic(ZSTD_CStream* zcs,
                                           ZSTD_inBuffer* input,
                                           ZSTD_EndDirective const flushMode)
 {
-    const char* const istart = (const char*)input->src;
+    const char* const istart = (assert(input != NULL), (const char*)input->src);
     const char* const iend = (istart != NULL) ? istart + input->size : istart;
     const char* ip = (istart != NULL) ? istart + input->pos : istart;
-    char* const ostart = (char*)output->dst;
+    char* const ostart = (assert(output != NULL), (char*)output->dst);
     char* const oend = (ostart != NULL) ? ostart + output->size : ostart;
     char* op = (ostart != NULL) ? ostart + output->pos : ostart;
     U32 someMoreWork = 1;
 
     /* check expectations */
-    DEBUGLOG(5, "ZSTD_compressStream_generic, flush=%i", (int)flushMode);
+    DEBUGLOG(5, "ZSTD_compressStream_generic, flush=%i, srcSize = %zu", (int)flushMode, input->size - input->pos);
+    assert(zcs != NULL);
+    if (zcs->appliedParams.inBufferMode == ZSTD_bm_stable) {
+        assert(input->pos >= zcs->stableIn_notConsumed);
+        input->pos -= zcs->stableIn_notConsumed;
+        ip -= zcs->stableIn_notConsumed;
+        zcs->stableIn_notConsumed = 0;
+    }
     if (zcs->appliedParams.inBufferMode == ZSTD_bm_buffered) {
         assert(zcs->inBuff != NULL);
         assert(zcs->inBuffSize > 0);
@@ -5347,8 +5359,10 @@ static size_t ZSTD_compressStream_generic(ZSTD_CStream* zcs,
         assert(zcs->outBuff !=  NULL);
         assert(zcs->outBuffSize > 0);
     }
-    assert(output->pos <= output->size);
+    if (input->src == NULL) assert(input->size == 0);
     assert(input->pos <= input->size);
+    if (output->dst == NULL) assert(output->size == 0);
+    assert(output->pos <= output->size);
     assert((U32)flushMode <= (U32)ZSTD_e_end);
 
     while (someMoreWork) {
@@ -5380,8 +5394,7 @@ static size_t ZSTD_compressStream_generic(ZSTD_CStream* zcs,
                                         zcs->inBuff + zcs->inBuffPos, toLoad,
                                         ip, iend-ip);
                 zcs->inBuffPos += loaded;
-                if (loaded != 0)
-                    ip += loaded;
+                if (ip) ip += loaded;
                 if ( (flushMode == ZSTD_e_continue)
                   && (zcs->inBuffPos < zcs->inBuffTarget) ) {
                     /* not enough input to fill full block : stop here */
@@ -5392,6 +5405,20 @@ static size_t ZSTD_compressStream_generic(ZSTD_CStream* zcs,
                     /* empty */
                     someMoreWork = 0; break;
                 }
+            } else {
+                assert(zcs->appliedParams.inBufferMode == ZSTD_bm_stable);
+                if ( (flushMode == ZSTD_e_continue)
+                  && ( (size_t)(iend - ip) < zcs->blockSize) ) {
+                    /* can't compress a full block : stop here */
+                    zcs->stableIn_notConsumed = (size_t)(iend - ip);
+                    ip = iend;  /* pretend to have consumed input */
+                    someMoreWork = 0; break;
+                }
+                if ( (flushMode == ZSTD_e_flush)
+                  && (ip == iend) ) {
+                    /* empty */
+                    someMoreWork = 0; break;
+                }
             }
             /* compress current block (note : this stage cannot be stopped in the middle) */
             DEBUGLOG(5, "stream compression stage (flushMode==%u)", flushMode);
@@ -5399,9 +5426,8 @@ static size_t ZSTD_compressStream_generic(ZSTD_CStream* zcs,
                 void* cDst;
                 size_t cSize;
                 size_t oSize = oend-op;
-                size_t const iSize = inputBuffered
-                    ? zcs->inBuffPos - zcs->inToCompress
-                    : MIN((size_t)(iend - ip), zcs->blockSize);
+                size_t const iSize = inputBuffered ? zcs->inBuffPos - zcs->inToCompress
+                                                   : MIN((size_t)(iend - ip), zcs->blockSize);
                 if (oSize >= ZSTD_compressBound(iSize) || zcs->appliedParams.outBufferMode == ZSTD_bm_stable)
                     cDst = op;   /* compress into output buffer, to skip flush stage */
                 else
@@ -5425,17 +5451,15 @@ static size_t ZSTD_compressStream_generic(ZSTD_CStream* zcs,
                         assert(zcs->inBuffTarget <= zcs->inBuffSize);
                     zcs->inToCompress = zcs->inBuffPos;
                 } else { /* !inputBuffered, hence ZSTD_bm_stable */
-                    unsigned const lastBlock = (ip + iSize == iend);
+                    unsigned const lastBlock = (flushMode == ZSTD_e_end) && (ip + iSize == iend);
                     cSize = lastBlock ?
                             ZSTD_compressEnd(zcs, cDst, oSize, ip, iSize) :
                             ZSTD_compressContinue(zcs, cDst, oSize, ip, iSize);
                     /* Consume the input prior to error checking to mirror buffered mode. */
-                    if (iSize > 0)
-                        ip += iSize;
+                    if (ip) ip += iSize;
                     FORWARD_IF_ERROR(cSize, "%s", lastBlock ? "ZSTD_compressEnd failed" : "ZSTD_compressContinue failed");
                     zcs->frameEnded = lastBlock;
-                    if (lastBlock)
-                        assert(ip == iend);
+                    if (lastBlock) assert(ip == iend);
                 }
                 if (cDst == op) {  /* no need to flush */
                     op += cSize;
@@ -5514,6 +5538,7 @@ size_t ZSTD_compressStream(ZSTD_CStream* zcs, ZSTD_outBuffer* output, ZSTD_inBuf
 static void
 ZSTD_setBufferExpectations(ZSTD_CCtx* cctx, const ZSTD_outBuffer* output, const ZSTD_inBuffer* input)
 {
+    DEBUGLOG(5, "ZSTD_setBufferExpectations (for advanced stable in/out modes)");
     if (cctx->appliedParams.inBufferMode == ZSTD_bm_stable) {
         cctx->expectedInBuffer = *input;
     }
@@ -5647,27 +5672,25 @@ size_t ZSTD_compressStream2( ZSTD_CCtx* cctx,
     /* transparent initialization stage */
     if (cctx->streamStage == zcss_init) {
         size_t const inputSize = input->size - input->pos;  /* no obligation to start from pos==0 */
-        size_t const totalInputSize = (cctx->savedInPosPlusOne == 0) ? inputSize : input->size - (cctx->savedInPosPlusOne - 1);
+        size_t const totalInputSize = inputSize + cctx->stableIn_notConsumed;
         if ( (cctx->requestedParams.inBufferMode == ZSTD_bm_stable) /* input is presumed stable, across invocations */
           && (endOp == ZSTD_e_continue)                             /* no flush requested, more input to come */
           && (totalInputSize < ZSTD_BLOCKSIZE_MAX) ) {              /* not even reached one block yet */
-            if (cctx->savedInPosPlusOne) {  /* not the first time */
+            if (cctx->stableIn_notConsumed) {  /* not the first time */
                 /* check stable source guarantees */
                 assert(input->src == cctx->expectedInBuffer.src);
                 assert(input->pos == cctx->expectedInBuffer.size);
             }
-            /* keep track of first position */
-            if (cctx->savedInPosPlusOne == 0) cctx->savedInPosPlusOne = input->pos + 1;
-            cctx->expectedInBuffer = *input;
             /* pretend input was consumed, to give a sense forward progress */
             input->pos = input->size;
+            /* save stable inBuffer, for later control, and flush/end */
+            cctx->expectedInBuffer = *input;
             /* but actually input wasn't consumed, so keep track of position from where compression shall resume */
-            cctx->expectedInBuffer.pos = cctx->savedInPosPlusOne - 1;
+            cctx->stableIn_notConsumed += inputSize;
             /* don't initialize yet, wait for the first block of flush() order, for better parameters adaptation */
             return ZSTD_FRAMEHEADERSIZE_MIN(cctx->requestedParams.format);  /* at least some header to produce */
         }
         FORWARD_IF_ERROR(ZSTD_CCtx_init_compressStream2(cctx, endOp, totalInputSize), "compressStream2 initialization failed");
-        cctx->savedInPosPlusOne = 0;
         ZSTD_setBufferExpectations(cctx, output, input);   /* Set initial buffer expectations now that we've initialized */
     }
     /* end of transparent initialization stage */
@@ -5681,6 +5704,13 @@ size_t ZSTD_compressStream2( ZSTD_CCtx* cctx,
             ZSTDMT_updateCParams_whileCompressing(cctx->mtctx, &cctx->requestedParams);
             cctx->cParamsChanged = 0;
         }
+        if (cctx->stableIn_notConsumed) {
+            assert(cctx->appliedParams.inBufferMode == ZSTD_bm_stable);
+            /* some early data was skipped - make it available for consumption */
+            assert(input->pos >= cctx->stableIn_notConsumed);
+            input->pos -= cctx->stableIn_notConsumed;
+            cctx->stableIn_notConsumed = 0;
+        }
         for (;;) {
             size_t const ipos = input->pos;
             size_t const opos = output->pos;
index 8cc2f81d6d0aa549df2bdbb068cf82b37175868c..efbf89ae37bdcf69dc8ae43a5566d1363645b2b2 100644 (file)
@@ -410,8 +410,8 @@ struct ZSTD_CCtx_s {
 
     /* Stable in/out buffer verification */
     ZSTD_inBuffer expectedInBuffer;
+    size_t stableIn_notConsumed; /* nb bytes within stable input buffer that are said to be consumed but are not */
     size_t expectedOutBufferSize;
-    size_t savedInPosPlusOne;  /* 0 == no savedInPos */
 
     /* Dictionary */
     ZSTD_localDict localDict;
index 823db775f354a32a61e01bd92e2a0d84f47c5ff0..ddb2ad3938b088034f600fdfabab04321ee031da 100644 (file)
@@ -1206,7 +1206,7 @@ static int basicUnitTests(U32 const seed, double compressibility)
 
     DISPLAYLEVEL(3, "test%3i : compress a NULL input with each level : ", testNb++);
     {   int level = -1;
-        ZSTD_CCtx* cctx = ZSTD_createCCtx();
+        ZSTD_CCtx* const cctx = ZSTD_createCCtx();
         if (!cctx) goto _output_error;
         for (level = -1; level <= ZSTD_maxCLevel(); ++level) {
           CHECK_Z( ZSTD_compress(compressedBuffer, compressedBufferSize, NULL, 0, level) );
index 1444d27df9a9f63e14a1139dbefb12ecc842f342..e084924b2d0dc9d8a32f2a2e7de5c49cbe018550 100644 (file)
@@ -878,20 +878,26 @@ static int basicUnitTests(U32 seed, double compressibility)
         CHECK_Z( ZSTD_CCtx_setParameter(cctx, ZSTD_c_stableInBuffer, 1) );
         {   ZSTD_inBuffer inBuf;
             ZSTD_outBuffer outBuf;
+            const size_t inputSize = 500;
             inBuf.src = CNBuffer;
             inBuf.size = 100;
             inBuf.pos = 0;
             outBuf.dst = (char*)(compressedBuffer)+cSize;
-            outBuf.size = ZSTD_compressBound(500);
+            outBuf.size = ZSTD_compressBound(inputSize);
             outBuf.pos = 0;
             CHECK_Z( ZSTD_compressStream(cctx, &outBuf, &inBuf) );
             inBuf.size = 200;
             CHECK_Z( ZSTD_compressStream(cctx, &outBuf, &inBuf) );
             CHECK_Z( ZSTD_flushStream(cctx, &outBuf) );
-            inBuf.size = 300;
+            inBuf.size = inputSize;
             CHECK_Z( ZSTD_compressStream(cctx, &outBuf, &inBuf) );
             CHECK(ZSTD_endStream(cctx, &outBuf) != 0, "compression should be successful and fully flushed");
-        }
+            {   void* const verifBuf = (char*)outBuf.dst + outBuf.pos;
+                const size_t decSize = ZSTD_decompress(verifBuf, inputSize, outBuf.dst, outBuf.pos);
+                CHECK_Z(decSize);
+                CHECK(decSize != inputSize, "regenerated %zu bytes, instead of %zu", decSize, inputSize);
+                CHECK(memcmp(inBuf.src, verifBuf, inputSize) != 0, "regenerated data different from original");
+        }   }
         DISPLAYLEVEL(3, "OK \n");
 
         DISPLAYLEVEL(3, "test%3i : ZSTD_compressStream2() with ZSTD_c_stableInBuffer: context size : ", testNb++);