git.ipfire.org Git - thirdparty/zstd.git/commitdiff
Add support for in-place decompression
author Nick Terrell <terrelln@fb.com>
Thu, 12 Jan 2023 02:14:40 +0000 (18:14 -0800)
committer Nick Terrell <nickrterrell@gmail.com>
Fri, 13 Jan 2023 00:28:08 +0000 (16:28 -0800)
* Add a function ZSTD_decompressionMargin() and a macro
  ZSTD_DECOMPRESSION_MARGIN() that compute the decompression margin for
  in-place decompression. The function computes a tight margin that works
  in all cases, and the macro computes an upper bound that will only work
  if flush isn't used.
* When doing in-place decompression, make sure that our output buffer
  doesn't overlap with the input buffer. This ensures that we don't
  decide to use the portion of the output buffer that overlaps the input
  buffer for temporary memory, like for literals.
* Add a simple unit test.
* Add in-place decompression to the simple_round_trip and
  stream_round_trip fuzzers. This should help verify that our margin stays
  correct.

lib/common/zstd_internal.h
lib/decompress/zstd_decompress.c
lib/legacy/zstd_legacy.h
lib/zstd.h
tests/fuzz/simple_round_trip.c
tests/fuzz/stream_round_trip.c
tests/fuzzer.c

index 48558873d62f41151c6960734a21562f95dc6eb5..12e1106a1e84d9e191f765d7b82b50d5f6728019 100644 (file)
@@ -341,6 +341,7 @@ MEM_STATIC ZSTD_sequenceLength ZSTD_getSequenceLength(seqStore_t const* seqStore
  *          `decompressedBound != ZSTD_CONTENTSIZE_ERROR`
  */
 typedef struct {
+    size_t nbBlocks; /* number of blocks counted in the frame */
     size_t compressedSize; /* size of the frame in the source buffer, or an error code (check with ZSTD_isError()) */
     unsigned long long decompressedBound; /* upper bound of the decompressed size, or ZSTD_CONTENTSIZE_ERROR */
 } ZSTD_frameSizeInfo;   /* decompress & legacy */
index f00ef3a67aeae98242c857c6650438eb1c3603e3..4559451d00c5a6d820673497738ec2979433292a 100644 (file)
@@ -782,6 +782,7 @@ static ZSTD_frameSizeInfo ZSTD_findFrameSizeInfo(const void* src, size_t srcSize
             ip += 4;
         }
 
+        frameSizeInfo.nbBlocks = nbBlocks;
         frameSizeInfo.compressedSize = (size_t)(ip - ipstart);
         frameSizeInfo.decompressedBound = (zfh.frameContentSize != ZSTD_CONTENTSIZE_UNKNOWN)
                                         ? zfh.frameContentSize
@@ -825,6 +826,48 @@ unsigned long long ZSTD_decompressBound(const void* src, size_t srcSize)
     return bound;
 }
 
+/* ZSTD_decompressionMargin() :
+ * Walks every frame found in [src, src + srcSize) and accumulates the margin
+ * needed for in-place decompression.
+ *  - A zstd frame contributes its header size, 4 bytes when the checksum flag
+ *    is set, and 3 bytes per block.
+ *  - A skippable frame contributes its whole size.
+ * The largest blockSizeMax seen across the zstd frames is added once at the end.
+ *
+ * @return the margin, or an error code: the error forwarded from
+ *         ZSTD_getFrameHeader(), or corruption_detected when the frame size
+ *         information cannot be computed.
+ */
+size_t ZSTD_decompressionMargin(void const* src, size_t srcSize)
+{
+    const BYTE* ip = (const BYTE*)src;   /* start of the frame being inspected */
+    size_t remaining = srcSize;          /* bytes left from ip to the end of the input */
+    size_t overhead = 0;
+    unsigned largestBlock = 0;
+
+    while (remaining > 0) {
+        ZSTD_frameSizeInfo const info = ZSTD_findFrameSizeInfo(ip, remaining);
+        ZSTD_frameHeader zfh;
+
+        FORWARD_IF_ERROR(ZSTD_getFrameHeader(&zfh, ip, remaining), "");
+        if (ZSTD_isError(info.compressedSize) || info.decompressedBound == ZSTD_CONTENTSIZE_ERROR)
+            return ERROR(corruption_detected);
+
+        if (zfh.frameType != ZSTD_frame) {
+            /* Skippable frame: its entire size counts towards the margin. */
+            assert(zfh.frameType == ZSTD_skippableFrame);
+            overhead += info.compressedSize;
+        } else {
+            /* Frame header */
+            overhead += zfh.headerSize;
+            /* Checksum, when present */
+            if (zfh.checksumFlag)
+                overhead += 4;
+            /* 3 bytes per block */
+            overhead += 3 * info.nbBlocks;
+
+            /* Track the largest block size over all frames. */
+            if (zfh.blockSizeMax > largestBlock)
+                largestBlock = zfh.blockSizeMax;
+        }
+
+        /* Step over the frame that was just accounted for. */
+        assert(remaining >= info.compressedSize);
+        ip += info.compressedSize;
+        remaining -= info.compressedSize;
+    }
+
+    /* One block of the largest size on top of the per-frame overhead. */
+    return overhead + largestBlock;
+}
 
 /*-*************************************************************
  *   Frame decoding
@@ -850,7 +893,7 @@ static size_t ZSTD_copyRawBlock(void* dst, size_t dstCapacity,
         if (srcSize == 0) return 0;
         RETURN_ERROR(dstBuffer_null, "");
     }
-    ZSTD_memcpy(dst, src, srcSize);
+    ZSTD_memmove(dst, src, srcSize);
     return srcSize;
 }
 
@@ -928,6 +971,7 @@ static size_t ZSTD_decompressFrame(ZSTD_DCtx* dctx,
 
     /* Loop on each block */
     while (1) {
+        BYTE* oBlockEnd = oend;
         size_t decodedSize;
         blockProperties_t blockProperties;
         size_t const cBlockSize = ZSTD_getcBlockSize(ip, remainingSrcSize, &blockProperties);
@@ -937,16 +981,34 @@ static size_t ZSTD_decompressFrame(ZSTD_DCtx* dctx,
         remainingSrcSize -= ZSTD_blockHeaderSize;
         RETURN_ERROR_IF(cBlockSize > remainingSrcSize, srcSize_wrong, "");
 
+        if (ip >= op && ip < oBlockEnd) {
+            /* We are decompressing in-place. Limit the output pointer so that we
+             * don't overwrite the block that we are currently reading. This will
+             * fail decompression if the input & output pointers aren't spaced
+             * far enough apart.
+             *
+             * This is important to set, even when the pointers are far enough
+             * apart, because ZSTD_decompressBlock_internal() can decide to store
+             * literals in the output buffer, after the block it is decompressing.
+             * Since we don't want anything to overwrite our input, we have to tell
+             * ZSTD_decompressBlock_internal to never write past ip.
+             *
+             * See ZSTD_allocateLiteralsBuffer() for reference.
+             */
+            oBlockEnd = op + (ip - op);
+        }
+
         switch(blockProperties.blockType)
         {
         case bt_compressed:
-            decodedSize = ZSTD_decompressBlock_internal(dctx, op, (size_t)(oend-op), ip, cBlockSize, /* frame */ 1, not_streaming);
+            decodedSize = ZSTD_decompressBlock_internal(dctx, op, (size_t)(oBlockEnd-op), ip, cBlockSize, /* frame */ 1, not_streaming);
             break;
         case bt_raw :
+            /* Use oend instead of oBlockEnd because this function is safe to overlap. It uses memmove. */
             decodedSize = ZSTD_copyRawBlock(op, (size_t)(oend-op), ip, cBlockSize);
             break;
         case bt_rle :
-            decodedSize = ZSTD_setRleBlock(op, (size_t)(oend-op), *ip, blockProperties.origSize);
+            decodedSize = ZSTD_setRleBlock(op, (size_t)(oBlockEnd-op), *ip, blockProperties.origSize);
             break;
         case bt_reserved :
         default:
index 9f53d4cbd65a05f3f7cbb7b2e7e15af31facf457..dd173251d34d5fc8cea084c09b568d29a03b08bb 100644 (file)
@@ -242,6 +242,13 @@ MEM_STATIC ZSTD_frameSizeInfo ZSTD_findFrameSizeInfoLegacy(const void *src, size
         frameSizeInfo.compressedSize = ERROR(srcSize_wrong);
         frameSizeInfo.decompressedBound = ZSTD_CONTENTSIZE_ERROR;
     }
+    /* In all cases, decompressedBound == nbBlocks * ZSTD_BLOCKSIZE_MAX.
+     * So we can compute nbBlocks without having to change every function.
+     */
+    if (frameSizeInfo.decompressedBound != ZSTD_CONTENTSIZE_ERROR) {
+        assert((frameSizeInfo.decompressedBound & (ZSTD_BLOCKSIZE_MAX - 1)) == 0);
+        frameSizeInfo.nbBlocks = (size_t)(frameSizeInfo.decompressedBound / ZSTD_BLOCKSIZE_MAX);
+    }
     return frameSizeInfo;
 }
 
index 480d65f675e9482ae86a8a91d6c96fefbf855f46..22c8bba5b59dba00aa542e3df44f287a5647625f 100644 (file)
@@ -1427,6 +1427,51 @@ ZSTDLIB_STATIC_API unsigned long long ZSTD_decompressBound(const void* src, size
  *           or an error code (if srcSize is too small) */
 ZSTDLIB_STATIC_API size_t ZSTD_frameHeaderSize(const void* src, size_t srcSize);
 
+/*! ZSTD_decompressionMargin() :
+ * Zstd supports in-place decompression, where the input and output buffers overlap.
+ * In this case, the output buffer must be at least (Margin + Output_Size) bytes large,
+ * and the input buffer must be at the end of the output buffer.
+ *
+ *  _______________________ Output Buffer ________________________
+ * |                                                              |
+ * |                                        ____ Input Buffer ____|
+ * |                                       |                      |
+ * v                                       v                      v
+ * |---------------------------------------|-----------|----------|
+ * ^                                                   ^          ^
+ * |___________________ Output_Size ___________________|_ Margin _|
+ *
+ * NOTE: See also ZSTD_DECOMPRESSION_MARGIN().
+ * NOTE: This applies only to single-pass decompression through ZSTD_decompress() or
+ * ZSTD_decompressDCtx().
+ * NOTE: This function supports multi-frame input.
+ *
+ * @param src The compressed frame(s)
+ * @param srcSize The size of the compressed frame(s)
+ * @returns The decompression margin or an error that can be checked with ZSTD_isError().
+ */
+ZSTDLIB_STATIC_API size_t ZSTD_decompressionMargin(const void* src, size_t srcSize);
+
+/*! ZSTD_DECOMPRESSION_MARGIN() :
+ * Similar to ZSTD_decompressionMargin(), but instead of computing the margin from
+ * the compressed frame, compute it from the original size and the block size.
+ * See ZSTD_decompressionMargin() for details.
+ *
+ * WARNING: This macro does not support multi-frame input, the input must be a single
+ * zstd frame. If you need that support use the function, or implement it yourself.
+ *
+ * NOTE: Both arguments are expanded more than once, so pass expressions without
+ * side effects. Every use of an argument is parenthesized, so compound
+ * expressions such as (a + b) are safe to pass.
+ *
+ * @param originalSize The original uncompressed size of the data.
+ * @param blockSize    The block size == MIN(windowSize, ZSTD_BLOCKSIZE_MAX).
+ *                     Unless you explicitly set the windowLog smaller than
+ *                     ZSTD_BLOCKSIZELOG_MAX you can just use ZSTD_BLOCKSIZE_MAX.
+ */
+#define ZSTD_DECOMPRESSION_MARGIN(originalSize, blockSize) ((size_t)(                                              \
+        ZSTD_FRAMEHEADERSIZE_MAX                                                              /* Frame header */ + \
+        4                                                                                         /* checksum */ + \
+        ((originalSize) == 0 ? 0 : 3 * (((originalSize) + (blockSize) - 1) / (blockSize))) /* 3 bytes per block */ + \
+        (blockSize)                                                                    /* One block of margin */   \
+    ))
+
 typedef enum {
   ZSTD_sf_noBlockDelimiters = 0,         /* Representation of ZSTD_Sequence has no block delimiters, sequences only */
   ZSTD_sf_explicitBlockDelimiters = 1    /* Representation of ZSTD_Sequence contains explicit block delimiters */
index 23a805af219781e80785815d9229a3e77a3c8c3f..c2c69d950b6d22b77de06f901dbb0d528eeb729f 100644 (file)
 static ZSTD_CCtx *cctx = NULL;
 static ZSTD_DCtx *dctx = NULL;
 
+/* Picks the margin used by the in-place round trip.
+ * Starts from ZSTD_decompressionMargin() on the compressed data. When
+ * hasSmallBlocks is zero, ZSTD_DECOMPRESSION_MARGIN() is evaluated as well
+ * (with the frame's blockSizeMax) and the smaller of the two values is returned.
+ */
+static size_t getDecompressionMargin(void const* compressed, size_t cSize, size_t srcSize, int hasSmallBlocks)
+{
+    size_t const exactMargin = ZSTD_decompressionMargin(compressed, cSize);
+    ZSTD_frameHeader zfh;
+    size_t macroMargin;
+
+    if (hasSmallBlocks)
+        return exactMargin;
+
+    /* The macro should be correct here, and it can come out smaller than the
+     * function's result (e.g. because of block splitting), so keep the minimum.
+     */
+    FUZZ_ZASSERT(ZSTD_getFrameHeader(&zfh, compressed, cSize));
+    macroMargin = ZSTD_DECOMPRESSION_MARGIN(srcSize, zfh.blockSizeMax);
+    return (macroMargin < exactMargin) ? macroMargin : exactMargin;
+}
+
 static size_t roundTripTest(void *result, size_t resultCapacity,
                             void *compressed, size_t compressedCapacity,
                             const void *src, size_t srcSize,
@@ -67,6 +84,25 @@ static size_t roundTripTest(void *result, size_t resultCapacity,
     }
     dSize = ZSTD_decompressDCtx(dctx, result, resultCapacity, compressed, cSize);
     FUZZ_ZASSERT(dSize);
+    FUZZ_ASSERT_MSG(dSize == srcSize, "Incorrect regenerated size");
+    FUZZ_ASSERT_MSG(!FUZZ_memcmp(src, result, dSize), "Corruption!");
+
+    {
+        size_t margin = getDecompressionMargin(compressed, cSize, srcSize, targetCBlockSize);
+        size_t const outputSize = srcSize + margin;
+        char* const output = (char*)FUZZ_malloc(outputSize);
+        char* const input = output + outputSize - cSize;
+        FUZZ_ASSERT(outputSize >= cSize);
+        memcpy(input, compressed, cSize);
+
+        dSize = ZSTD_decompressDCtx(dctx, output, outputSize, input, cSize);
+        FUZZ_ZASSERT(dSize);
+        FUZZ_ASSERT_MSG(dSize == srcSize, "Incorrect regenerated size");
+        FUZZ_ASSERT_MSG(!FUZZ_memcmp(src, output, srcSize), "Corruption!");
+
+        free(output);
+    }
+
     /* When superblock is enabled make sure we don't expand the block more than expected.
      * NOTE: This test is currently disabled because superblock mode can arbitrarily
      * expand the block in the worst case. Once superblock mode has been improved we can
@@ -120,13 +156,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *src, size_t size)
         FUZZ_ASSERT(dctx);
     }
 
-    {
-        size_t const result =
-            roundTripTest(rBuf, rBufSize, cBuf, cBufSize, src, size, producer);
-        FUZZ_ZASSERT(result);
-        FUZZ_ASSERT_MSG(result == size, "Incorrect regenerated size");
-        FUZZ_ASSERT_MSG(!FUZZ_memcmp(src, rBuf, size), "Corruption!");
-    }
+    roundTripTest(rBuf, rBufSize, cBuf, cBufSize, src, size, producer);
     free(rBuf);
     free(cBuf);
     FUZZ_dataProducer_free(producer);
index 8a28907b6ff1d590d904c130d61fd83de99b2aef..fae9ccbf498092ae45c14077eca36e7065af4d71 100644 (file)
@@ -166,6 +166,24 @@ int LLVMFuzzerTestOneInput(const uint8_t *src, size_t size)
         FUZZ_ZASSERT(rSize);
         FUZZ_ASSERT_MSG(rSize == size, "Incorrect regenerated size");
         FUZZ_ASSERT_MSG(!FUZZ_memcmp(src, rBuf, size), "Corruption!");
+
+        /* Test in-place decompression (note the macro doesn't work in this case) */
+        {
+            size_t const margin = ZSTD_decompressionMargin(cBuf, cSize);
+            size_t const outputSize = size + margin;
+            char* const output = (char*)FUZZ_malloc(outputSize);
+            char* const input = output + outputSize - cSize;
+            size_t dSize;
+            FUZZ_ASSERT(outputSize >= cSize);
+            memcpy(input, cBuf, cSize);
+
+            dSize = ZSTD_decompressDCtx(dctx, output, outputSize, input, cSize);
+            FUZZ_ZASSERT(dSize);
+            FUZZ_ASSERT_MSG(dSize == size, "Incorrect regenerated size");
+            FUZZ_ASSERT_MSG(!FUZZ_memcmp(src, output, size), "Corruption!");
+
+            free(output);
+        }
     }
 
     FUZZ_dataProducer_free(producer);
index 3ad8ced5e5df7acb418566d9c75f5c35cc63c1c7..ce0bea57348ecf1f2f983c6724892f1dbc809187 100644 (file)
@@ -1220,6 +1220,60 @@ static int basicUnitTests(U32 const seed, double compressibility)
     }
     DISPLAYLEVEL(3, "OK \n");
 
+    DISPLAYLEVEL(3, "test%3i : in-place decompression : ", testNb++);
+    cSize = ZSTD_compress(compressedBuffer, compressedBufferSize, CNBuffer, CNBuffSize, -ZSTD_BLOCKSIZE_MAX);
+    CHECK_Z(cSize);
+    CHECK_LT(CNBuffSize, cSize);
+    {
+        size_t const margin = ZSTD_decompressionMargin(compressedBuffer, cSize);
+        size_t const outputSize = (CNBuffSize + margin);
+        char* output = malloc(outputSize);
+        char* input = output + outputSize - cSize;
+        CHECK_LT(cSize, CNBuffSize + margin);
+        CHECK(output != NULL);
+        CHECK_Z(margin);
+        CHECK(margin <= ZSTD_DECOMPRESSION_MARGIN(CNBuffSize, ZSTD_BLOCKSIZE_MAX));
+        memcpy(input, compressedBuffer, cSize);
+
+        {
+            size_t const dSize = ZSTD_decompress(output, outputSize, input, cSize);
+            CHECK_Z(dSize);
+            CHECK_EQ(dSize, CNBuffSize);
+        }
+        CHECK(!memcmp(output, CNBuffer, CNBuffSize));
+        free(output);
+    }
+    DISPLAYLEVEL(3, "OK \n");
+
+    DISPLAYLEVEL(3, "test%3i : in-place decompression with 2 frames : ", testNb++);
+    cSize = ZSTD_compress(compressedBuffer, compressedBufferSize, CNBuffer, CNBuffSize / 3, -ZSTD_BLOCKSIZE_MAX);
+    CHECK_Z(cSize);
+    {
+        size_t const cSize2 = ZSTD_compress((char*)compressedBuffer + cSize, compressedBufferSize - cSize, (char const*)CNBuffer + (CNBuffSize / 3), CNBuffSize / 3, -ZSTD_BLOCKSIZE_MAX);
+        CHECK_Z(cSize2);
+        cSize += cSize2;
+    }
+    {
+        size_t const srcSize = (CNBuffSize / 3) * 2;
+        size_t const margin = ZSTD_decompressionMargin(compressedBuffer, cSize);
+        size_t const outputSize = (CNBuffSize + margin);
+        char* output = malloc(outputSize);
+        char* input = output + outputSize - cSize;
+        CHECK_LT(cSize, CNBuffSize + margin);
+        CHECK(output != NULL);
+        CHECK_Z(margin);
+        memcpy(input, compressedBuffer, cSize);
+
+        {
+            size_t const dSize = ZSTD_decompress(output, outputSize, input, cSize);
+            CHECK_Z(dSize);
+            CHECK_EQ(dSize, srcSize);
+        }
+        CHECK(!memcmp(output, CNBuffer, srcSize));
+        free(output);
+    }
+    DISPLAYLEVEL(3, "OK \n");
+
     DISPLAYLEVEL(3, "test%3d: superblock uncompressible data, too many nocompress superblocks : ", testNb++);
     {
         ZSTD_CCtx* const cctx = ZSTD_createCCtx();