]> git.ipfire.org Git - thirdparty/zstd.git/commitdiff
[lib] Add ZSTD_d_stableOutBuffer
authorNick Terrell <terrelln@fb.com>
Tue, 28 Apr 2020 00:42:03 +0000 (17:42 -0700)
committerNick Terrell <terrelln@fb.com>
Tue, 28 Apr 2020 01:09:44 +0000 (18:09 -0700)
lib/common/error_private.c
lib/common/zstd_errors.h
lib/decompress/zstd_decompress.c
lib/decompress/zstd_decompress_internal.h
lib/zstd.h
tests/zstreamtest.c

index 3920584409136d9a74e78f230327ffaec3cf1a46..cd437529c12bff0a4577e5fd3aeba3f724cc2778 100644 (file)
@@ -47,6 +47,7 @@ const char* ERR_getErrorString(ERR_enum code)
         /* following error codes are not stable and may be removed or changed in a future version */
     case PREFIX(frameIndex_tooLarge): return "Frame index is too large";
     case PREFIX(seekableIO): return "An I/O error occurred when reading/seeking";
+    case PREFIX(dstBuffer_wrong): return "Destination buffer is wrong";
     case PREFIX(maxCode):
     default: return notErrorCode;
     }
index 238309457b3f1c61f48e2a626f917aa8ad86dd3f..998398e7e57f299cbcf5142cb9bf11ce16d3b0f2 100644 (file)
@@ -76,6 +76,7 @@ typedef enum {
   /* following error codes are __NOT STABLE__, they can be removed or changed in future versions */
   ZSTD_error_frameIndex_tooLarge = 100,
   ZSTD_error_seekableIO          = 102,
+  ZSTD_error_dstBuffer_wrong     = 104,
   ZSTD_error_maxCode = 120  /* never EVER use this value directly, it can change in future versions! Use ZSTD_isError() instead */
 } ZSTD_ErrorCode;
 
index 6ef156b658ddc91e38ef5ef4f80ffcb36b45dcf2..9cd637354acd9302f2c2c73d6a4f98859d14a33d 100644 (file)
@@ -113,6 +113,7 @@ static void ZSTD_initDCtx_internal(ZSTD_DCtx* dctx)
     dctx->noForwardProgress = 0;
     dctx->oversizedDuration = 0;
     dctx->bmi2 = ZSTD_cpuid_bmi2(ZSTD_cpuid());
+    dctx->outBufferMode = ZSTD_obm_buffered;
 }
 
 ZSTD_DCtx* ZSTD_initStaticDCtx(void *workspace, size_t workspaceSize)
@@ -1402,6 +1403,10 @@ ZSTD_bounds ZSTD_dParam_getBounds(ZSTD_dParameter dParam)
             bounds.upperBound = (int)ZSTD_f_zstd1_magicless;
             ZSTD_STATIC_ASSERT(ZSTD_f_zstd1 < ZSTD_f_zstd1_magicless);
             return bounds;
+        case ZSTD_d_stableOutBuffer:
+            bounds.lowerBound = (int)ZSTD_obm_buffered;
+            bounds.upperBound = (int)ZSTD_obm_stable;
+            return bounds;
         default:;
     }
     bounds.error = ERROR(parameter_unsupported);
@@ -1437,6 +1442,10 @@ size_t ZSTD_DCtx_setParameter(ZSTD_DCtx* dctx, ZSTD_dParameter dParam, int value
             CHECK_DBOUNDS(ZSTD_d_format, value);
             dctx->format = (ZSTD_format_e)value;
             return 0;
+        case ZSTD_d_stableOutBuffer:
+            CHECK_DBOUNDS(ZSTD_d_stableOutBuffer, value);
+            dctx->outBufferMode = (ZSTD_outBufferMode_e)value;
+            return 0;
         default:;
     }
     RETURN_ERROR(parameter_unsupported);
@@ -1517,6 +1526,50 @@ static int ZSTD_DCtx_isOversizedTooLong(ZSTD_DStream* zds)
     return zds->oversizedDuration >= ZSTD_WORKSPACETOOLARGE_MAXDURATION;
 }
 
+static size_t ZSTD_checkOutBuffer(ZSTD_DStream const* zds, ZSTD_outBuffer const* output)
+{
+    ZSTD_outBuffer const expect = zds->expectedOutBuffer;
+    /* No requirement when ZSTD_obm_stable is not enabled. */
+    if (zds->outBufferMode != ZSTD_obm_stable)
+        return 0;
+    /* Any buffer is allowed in zdss_init, this must be the same for every other call until
+     * the context is reset.
+     */
+    if (zds->streamStage == zdss_init)
+        return 0;
+    /* The buffer must match our expectation exactly. */
+    if (expect.dst == output->dst && expect.pos == output->pos && expect.size == output->size)
+        return 0;
+    RETURN_ERROR(dstBuffer_wrong, "ZSTD_obm_stable enabled but output differs!");
+}
+
+static size_t ZSTD_decompressContinueStream(
+            ZSTD_DStream* zds, char** op, char* oend,
+            void const* src, size_t srcSize) {
+    int const isSkipFrame = ZSTD_isSkipFrame(zds);
+    if (zds->outBufferMode == ZSTD_obm_buffered) {
+        size_t const dstSize = isSkipFrame ? 0 : zds->outBuffSize - zds->outStart;
+        size_t const decodedSize = ZSTD_decompressContinue(zds,
+                zds->outBuff + zds->outStart, dstSize, src, srcSize);
+        FORWARD_IF_ERROR(decodedSize);
+        if (!decodedSize && !isSkipFrame) {
+            zds->streamStage = zdss_read;
+        } else {
+            zds->outEnd = zds->outStart + decodedSize;
+            zds->streamStage = zdss_flush;
+        }
+    } else {
+        size_t const dstSize = isSkipFrame ? 0 : oend - *op;
+        size_t const decodedSize = ZSTD_decompressContinue(zds, *op, dstSize, src, srcSize);
+        FORWARD_IF_ERROR(decodedSize);
+        *op += decodedSize;
+        zds->streamStage = zdss_read;
+        assert(*op <= oend);
+        assert(zds->outBufferMode == ZSTD_obm_stable);
+    }
+    return 0;
+}
+
 size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inBuffer* input)
 {
     const char* const src = (const char*)input->src;
@@ -1541,6 +1594,7 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB
         "forbidden. out: pos: %u   vs size: %u",
         (U32)output->pos, (U32)output->size);
     DEBUGLOG(5, "input size : %u", (U32)(input->size - input->pos));
+    FORWARD_IF_ERROR(ZSTD_checkOutBuffer(zds, output));
 
     while (someMoreWork) {
         switch(zds->streamStage)
@@ -1551,6 +1605,7 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB
             zds->lhSize = zds->inPos = zds->outStart = zds->outEnd = 0;
             zds->legacyVersion = 0;
             zds->hostageByte = 0;
+            zds->expectedOutBuffer = *output;
             /* fall-through */
 
         case zdss_loadHeader :
@@ -1621,6 +1676,14 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB
                     break;
             }   }
 
+            /* Check output buffer is large enough for ZSTD_odm_stable. */
+            if (zds->outBufferMode == ZSTD_obm_stable
+                && zds->fParams.frameType != ZSTD_skippableFrame
+                && zds->fParams.frameContentSize != ZSTD_CONTENTSIZE_UNKNOWN
+                && (U64)(size_t)(oend-op) < zds->fParams.frameContentSize) {
+                RETURN_ERROR(dstSize_tooSmall, "ZSTD_obm_stable passed but ZSTD_outBuffer is too small");
+            }
+
             /* Consume header (see ZSTDds_decodeFrameHeader) */
             DEBUGLOG(4, "Consume header");
             FORWARD_IF_ERROR(ZSTD_decompressBegin_usingDDict(zds, ZSTD_getDDict(zds)));
@@ -1644,7 +1707,9 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB
 
             /* Adapt buffer sizes to frame header instructions */
             {   size_t const neededInBuffSize = MAX(zds->fParams.blockSizeMax, 4 /* frame checksum */);
-                size_t const neededOutBuffSize = ZSTD_decodingBufferSize_min(zds->fParams.windowSize, zds->fParams.frameContentSize);
+                size_t const neededOutBuffSize = zds->outBufferMode == ZSTD_obm_buffered
+                        ? ZSTD_decodingBufferSize_min(zds->fParams.windowSize, zds->fParams.frameContentSize)
+                        : 0;
 
                 ZSTD_DCtx_updateOversizedDuration(zds, neededInBuffSize, neededOutBuffSize);
 
@@ -1687,15 +1752,9 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB
                     break;
                 }
                 if ((size_t)(iend-ip) >= neededInSize) {  /* decode directly from src */
-                    int const isSkipFrame = ZSTD_isSkipFrame(zds);
-                    size_t const decodedSize = ZSTD_decompressContinue(zds,
-                        zds->outBuff + zds->outStart, (isSkipFrame ? 0 : zds->outBuffSize - zds->outStart),
-                        ip, neededInSize);
-                    if (ZSTD_isError(decodedSize)) return decodedSize;
+                    FORWARD_IF_ERROR(ZSTD_decompressContinueStream(zds, &op, oend, ip, neededInSize));
                     ip += neededInSize;
-                    if (!decodedSize && !isSkipFrame) break;   /* this was just a header */
-                    zds->outEnd = zds->outStart + decodedSize;
-                    zds->streamStage = zdss_flush;
+                    /* Function modifies the stage so we must break */
                     break;
             }   }
             if (ip==iend) { someMoreWork = 0; break; }   /* no more input */
@@ -1722,17 +1781,11 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB
                 if (loadedSize < toLoad) { someMoreWork = 0; break; }   /* not enough input, wait for more */
 
                 /* decode loaded input */
-                {   size_t const decodedSize = ZSTD_decompressContinue(zds,
-                        zds->outBuff + zds->outStart, zds->outBuffSize - zds->outStart,
-                        zds->inBuff, neededInSize);
-                    if (ZSTD_isError(decodedSize)) return decodedSize;
-                    zds->inPos = 0;   /* input is consumed */
-                    if (!decodedSize && !isSkipFrame) { zds->streamStage = zdss_read; break; }   /* this was just a header */
-                    zds->outEnd = zds->outStart +  decodedSize;
-            }   }
-            zds->streamStage = zdss_flush;
-            /* fall-through */
-
+                zds->inPos = 0;   /* input is consumed */
+                FORWARD_IF_ERROR(ZSTD_decompressContinueStream(zds, &op, oend, zds->inBuff, neededInSize));
+                /* Function modifies the stage so we must break */
+                break;
+            }
         case zdss_flush:
             {   size_t const toFlushSize = zds->outEnd - zds->outStart;
                 size_t const flushedSize = ZSTD_limitCopy(op, oend-op, zds->outBuff + zds->outStart, toFlushSize);
@@ -1761,6 +1814,10 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB
     /* result */
     input->pos = (size_t)(ip - (const char*)(input->src));
     output->pos = (size_t)(op - (char*)(output->dst));
+
+    /* Update the expected output buffer for ZSTD_obm_stable. */
+    zds->expectedOutBuffer = *output;
+
     if ((ip==istart) && (op==ostart)) {  /* no forward progress */
         zds->noForwardProgress ++;
         if (zds->noForwardProgress >= ZSTD_NO_FORWARD_PROGRESS_MAX) {
index 29b4d0acc21b15f103058cf26e18919ca5dbbebe..f1c2585a662b952ff16e0d2d78e8f5e4200c04b8 100644 (file)
@@ -95,6 +95,11 @@ typedef enum {
     ZSTD_use_once = 1            /* Use the dictionary once and set to ZSTD_dont_use */
 } ZSTD_dictUses_e;
 
+typedef enum {
+    ZSTD_obm_buffered = 0,  /* Buffer the output */
+    ZSTD_obm_stable = 1     /* ZSTD_outBuffer is stable */
+} ZSTD_outBufferMode_e;
+
 struct ZSTD_DCtx_s
 {
     const ZSTD_seqSymbol* LLTptr;
@@ -147,6 +152,8 @@ struct ZSTD_DCtx_s
     U32 legacyVersion;
     U32 hostageByte;
     int noForwardProgress;
+    ZSTD_outBufferMode_e outBufferMode;
+    ZSTD_outBuffer expectedOutBuffer;
 
     /* workspace */
     BYTE litBuffer[ZSTD_BLOCKSIZE_MAX + WILDCOPY_OVERLENGTH];
index ae99b3039ba2d8cfebba7b3b6d0f8f13d75c5635..40e9fb162c3eeb9928b085d16eb52bfadaa9a184 100644 (file)
@@ -523,10 +523,12 @@ typedef enum {
      * within the experimental section of the API.
      * At the time of this writing, they include :
      * ZSTD_d_format
+     * ZSTD_d_stableOutBuffer
      * Because they are not stable, it's necessary to define ZSTD_STATIC_LINKING_ONLY to access them.
      * note : never ever use experimentalParam? names directly
      */
-     ZSTD_d_experimentalParam1=1000
+     ZSTD_d_experimentalParam1=1000,
+     ZSTD_d_experimentalParam2=1001
 
 } ZSTD_dParameter;
 
@@ -1645,6 +1647,29 @@ ZSTDLIB_API size_t ZSTD_DCtx_setMaxWindowSize(ZSTD_DCtx* dctx, size_t maxWindowS
  * allowing selection between ZSTD_format_e input compression formats
  */
 #define ZSTD_d_format ZSTD_d_experimentalParam1
+/* ZSTD_d_stableOutBuffer
+ * Experimental parameter.
+ * Default is 0 == disabled. Set to 1 to enable.
+ *
+ * Tells the decompressor that the ZSTD_outBuffer will ALWAYS be the same
+ * between calls, except for the modifications that zstd makes to pos (the
+ * caller must not modify pos). This is checked by the decompressor, and
+ * decompression will fail if it ever changes. Therefore the ZSTD_outBuffer
+ * MUST be large enough to fit the entire decompressed frame. This will be
+ * checked when the frame content size is known.
+ *
+ * When this flags is enabled zstd won't allocate an output buffer, because
+ * it can write directly to the ZSTD_outBuffer, but it will still allocate
+ * an input buffer large enough to fit any compressed block. This will also
+ * avoid the memcpy() from the internal output buffer to the ZSTD_outBuffer.
+ * If you need to avoid the input buffer allocation use the buffer-less
+ * streaming API.
+ *
+ * NOTE: So long as the ZSTD_outBuffer always points to valid memory, using
+ * this flag is ALWAYS memory safe, and will never access out-of-bounds
+ * memory. However, decompression WILL fail if you violate the preconditions.
+ */
+#define ZSTD_d_stableOutBuffer ZSTD_d_experimentalParam2
 
 /*! ZSTD_DCtx_setFormat() :
  *  Instruct the decoder context about what kind of data to decode next.
index a945125ac9952947891b8b44c1b981280786e652..002c16de95746ad328f751adbaff06834e6e98ac 100644 (file)
@@ -641,6 +641,92 @@ static int basicUnitTests(U32 seed, double compressibility)
     }
     DISPLAYLEVEL(3, "OK \n");
 
+    /* Decompression with ZSTD_d_stableOutBuffer */
+    cSize = ZSTD_compress(compressedBuffer, compressedBufferSize, CNBuffer, CNBufferSize, 1);
+    CHECK_Z(cSize);
+    {   ZSTD_DCtx* dctx = ZSTD_createDCtx();
+        size_t const dctxSize0 = ZSTD_sizeof_DCtx(dctx);        
+        size_t dctxSize1;
+        CHECK_Z(ZSTD_DCtx_setParameter(dctx, ZSTD_d_stableOutBuffer, 1));
+
+        outBuff.dst = decodedBuffer;
+        outBuff.pos = 0;
+        outBuff.size = CNBufferSize;
+
+        DISPLAYLEVEL(3, "test%3i : ZSTD_decompressStream() single pass : ", testNb++);
+        inBuff.src = compressedBuffer;
+        inBuff.size = cSize;
+        inBuff.pos = 0;
+        {   size_t const r = ZSTD_decompressStream(dctx, &outBuff, &inBuff);
+            CHECK_Z(r);
+            CHECK(r != 0, "Entire frame must be decompressed");
+            CHECK(outBuff.pos != CNBufferSize, "Wrong size!");
+            CHECK(memcmp(CNBuffer, outBuff.dst, CNBufferSize) != 0, "Corruption!");
+        }
+        CHECK(dctxSize0 != ZSTD_sizeof_DCtx(dctx), "No buffers allocated");
+        DISPLAYLEVEL(3, "OK \n");
+
+        DISPLAYLEVEL(3, "test%3i : ZSTD_decompressStream() stable out buffer : ", testNb++);
+        outBuff.pos = 0;
+        inBuff.pos = 0;
+        inBuff.size = 0;
+        while (inBuff.pos < cSize) {
+            inBuff.size += MIN(cSize - inBuff.pos, 1 + (FUZ_rand(&coreSeed) & 15));
+            CHECK_Z(ZSTD_decompressStream(dctx, &outBuff, &inBuff));
+        }
+        CHECK(outBuff.pos != CNBufferSize, "Wrong size!");
+        CHECK(memcmp(CNBuffer, outBuff.dst, CNBufferSize) != 0, "Corruption!");
+        dctxSize1 = ZSTD_sizeof_DCtx(dctx);
+        CHECK(!(dctxSize0 < dctxSize1), "Input buffer allocated");
+        DISPLAYLEVEL(3, "OK \n");
+
+        DISPLAYLEVEL(3, "test%3i : ZSTD_decompressStream() stable out buffer too small : ", testNb++);
+        ZSTD_DCtx_reset(dctx, ZSTD_reset_session_only);
+        CHECK_Z(ZSTD_DCtx_setParameter(dctx, ZSTD_d_stableOutBuffer, 1));
+        inBuff.src = compressedBuffer;
+        inBuff.size = cSize;
+        inBuff.pos = 0;
+        outBuff.pos = 0;
+        outBuff.size = CNBufferSize - 1;
+        {   size_t const r = ZSTD_decompressStream(dctx, &outBuff, &inBuff);
+            CHECK(ZSTD_getErrorCode(r) != ZSTD_error_dstSize_tooSmall, "Must error but got %s", ZSTD_getErrorName(r));
+        }
+        DISPLAYLEVEL(3, "OK \n");
+
+        DISPLAYLEVEL(3, "test%3i : ZSTD_decompressStream() stable out buffer modified : ", testNb++);
+        ZSTD_DCtx_reset(dctx, ZSTD_reset_session_only);
+        CHECK_Z(ZSTD_DCtx_setParameter(dctx, ZSTD_d_stableOutBuffer, 1));
+        inBuff.src = compressedBuffer;
+        inBuff.size = cSize - 1;
+        inBuff.pos = 0;
+        outBuff.pos = 0;
+        outBuff.size = CNBufferSize;
+        CHECK_Z(ZSTD_decompressStream(dctx, &outBuff, &inBuff));
+        ++inBuff.size;
+        outBuff.pos = 0;
+        {   size_t const r = ZSTD_decompressStream(dctx, &outBuff, &inBuff);
+            CHECK(ZSTD_getErrorCode(r) != ZSTD_error_dstBuffer_wrong, "Must error but got %s", ZSTD_getErrorName(r));
+        }
+        DISPLAYLEVEL(3, "OK \n");
+        
+        DISPLAYLEVEL(3, "test%3i : ZSTD_decompressStream() buffered output : ", testNb++);
+        ZSTD_DCtx_reset(dctx, ZSTD_reset_session_only);
+        CHECK_Z(ZSTD_DCtx_setParameter(dctx, ZSTD_d_stableOutBuffer, 0));
+        outBuff.pos = 0;
+        inBuff.pos = 0;
+        inBuff.size = 0;
+        while (inBuff.pos < cSize) {
+            inBuff.size += MIN(cSize - inBuff.pos, 1 + (FUZ_rand(&coreSeed) & 15));
+            CHECK_Z(ZSTD_decompressStream(dctx, &outBuff, &inBuff));
+        }
+        CHECK(outBuff.pos != CNBufferSize, "Wrong size!");
+        CHECK(memcmp(CNBuffer, outBuff.dst, CNBufferSize) != 0, "Corruption!");
+        CHECK(!(dctxSize1 < ZSTD_sizeof_DCtx(dctx)), "Output buffer allocated");
+        DISPLAYLEVEL(3, "OK \n");
+
+        ZSTD_freeDCtx(dctx);
+    }
+
     /* CDict scenario */
     DISPLAYLEVEL(3, "test%3i : digested dictionary : ", testNb++);
     {   ZSTD_CDict* const cdict = ZSTD_createCDict(dictionary.start, dictionary.filled, 1 /*byRef*/ );