]> git.ipfire.org Git - thirdparty/zstd.git/commitdiff
Overhaul logic to simplify, add in proper validations, fix match splitting
authorsenhuang42 <senhuang96@fb.com>
Mon, 16 Nov 2020 15:36:06 +0000 (10:36 -0500)
committersenhuang42 <senhuang96@fb.com>
Mon, 16 Nov 2020 15:49:17 +0000 (10:49 -0500)
lib/compress/zstd_compress.c

index e3e0fc601cefa8f59183778d1a85373746bb7e81..521752df01d388559972b2429713922565b94521 100644 (file)
@@ -4475,10 +4475,108 @@ size_t ZSTD_compress2(ZSTD_CCtx* cctx,
 }
 
 typedef struct {
-    U32 idx;             /* Index in array of ZSTD_Sequence*/
+    U32 idx;             /* Index in array of ZSTD_Sequence */
     U32 posInSequence;   /* Position within sequence at idx */
+    U64 posInSrc;        /* Position in src stream */
 } ZSTD_sequencePosition;
 
+#if defined(DEBUGLEVEL) && (DEBUGLEVEL >= 6)
+/* Returns a ZSTD error code if sequence is not valid */
+static size_t ZSTD_validateSequence(U32 offCode, U32 repCode, U32 matchLength,
+                                    size_t posInSrc, U32 windowLog, U32 dictSize) {
+    U32 offsetBound;
+    U32 windowSize = 1 << windowLog;
+    /* posInSrc represents the amount of data the the decoder would decode up to this point.
+     * As long as the amount of data decoded is less than or equal to window size, offsets may be
+     * larger than the total length of output decoded in order to reference the dict, even larger than
+     * window size. After output surpasses windowSize, we're limited to windowSize offsets again.
+     */
+    offsetBound = posInSrc > windowSize ? windowSize : posInSrc + dictSize;
+    RETURN_ERROR_IF(!repCode && offCode - ZSTD_REP_MOVE > offsetBound, corruption_detected, "Offset too large!");
+    RETURN_ERROR_IF(matchLength < MINMATCH, corruption_detected, "Matchlength too small");
+    return 0;
+}
+#endif
+
+/* Returns offset code, given a raw offset and repcode array */
+static U32 ZSTD_finalizeOffCode(U32 rawOffset, const U32* const rep, U32 ll0) {
+    U32 offCode = rawOffset + ZSTD_REP_MOVE;
+    U32 repCode = 0;
+
+    if (!ll0 && rawOffset == rep[0]) {
+        repCode = 1;
+    } else if (rawOffset == rep[1]) {
+        repCode = 2 - ll0;
+    } else if (rawOffset == rep[2]) {
+        repCode = 3 - ll0;
+    } else if (ll0 && rawOffset == rep[0] - 1) {
+        repCode = 3;
+    }
+    if (repCode) {
+        offCode = repCode - 1;
+    }
+    return offCode;
+}
+
+static size_t ZSTD_copySequencesToSeqStoreBlockDelim(seqStore_t* seqStore, ZSTD_sequencePosition* seqPos,
+                                           const ZSTD_Sequence* const inSeqs, size_t inSeqsSize,
+                                           const void* src, size_t blockSize, ZSTD_CCtx* cctx) {
+    size_t idx = seqPos->idx;
+    BYTE const* ip = (BYTE const*)(src);
+    BYTE const* iend = ip + blockSize;
+    repcodes_t updatedRepcodes;
+    U32 dictSize;
+    U32 litLength;
+    U32 matchLength;
+    U32 ll0;
+    U32 offCode;
+
+    if (cctx->cdict) {
+        dictSize = cctx->cdict->dictContentSize;
+    } else if (cctx->prefixDict.dictSize) {
+        dictSize = cctx->prefixDict.dictSize;
+    } else {
+        dictSize = 0;
+    }
+
+    ZSTD_memcpy(updatedRepcodes.rep, cctx->blockState.prevCBlock->rep, sizeof(repcodes_t));
+    for (; (inSeqs[idx].matchLength != 0 || inSeqs[idx].offset != 0) && idx < inSeqsSize; ++idx) {
+        litLength = inSeqs[idx].litLength;
+        matchLength = inSeqs[idx].matchLength;
+        ll0 = litLength == 0;
+        offCode = ZSTD_finalizeOffCode(inSeqs[idx].offset, updatedRepcodes.rep, ll0);
+        updatedRepcodes = ZSTD_updateRep(updatedRepcodes.rep, offCode, ll0);
+
+        DEBUGLOG(6, "Storing sequence: (of: %u, ml: %u, ll: %u)", offCode, matchLength, litLength);
+#if defined(DEBUGLEVEL) && (DEBUGLEVEL >= 6)
+        seqPos->posInSrc += litLength;
+        FORWARD_IF_ERROR(ZSTD_validateSequence(offCode, repCode, matchLength,
+                                               seqPos->posInSrc, cctx->appliedParams.cParams.windowLog,
+                                               dictSize),
+                         "Sequence validation failed");
+        seqPos->posInSrc += matchLength;
+#endif
+        ZSTD_storeSeq(seqStore, litLength, ip, iend, offCode, matchLength - MINMATCH);
+        ip += matchLength + litLength;
+    }
+    ZSTD_memcpy(cctx->blockState.nextCBlock->rep, updatedRepcodes.rep, sizeof(repcodes_t));
+
+    if (inSeqs[idx].litLength) {
+        DEBUGLOG(6, "Storing last literals of size: %u", inSeqs[idx].litLength);
+        ZSTD_storeLastLiterals(seqStore, ip, inSeqs[idx].litLength);
+        ip += inSeqs[idx].litLength;
+#if defined(DEBUGLEVEL) && (DEBUGLEVEL >= 6)
+        seqPos->posInSrc += inSeqs[idx].litLength;
+#endif
+    }
+    RETURN_ERROR_IF(ip != iend, corruption_detected, "Blocksize doesn't agree with block delimiter!");
+    seqPos->idx = idx+1;
+    return 0;
+}
+
+/* Returns the number of bytes to move the current read position back by. Only non-zero
+ * if we ended up splitting a sequence.
+ */
 static size_t ZSTD_copySequencesToSeqStore(seqStore_t* seqStore, ZSTD_sequencePosition* seqPos,
                                            const ZSTD_Sequence* const inSeqs, size_t inSeqsSize,
                                            const void* src, size_t blockSize, ZSTD_CCtx* cctx) {
@@ -4486,27 +4584,37 @@ static size_t ZSTD_copySequencesToSeqStore(seqStore_t* seqStore, ZSTD_sequencePo
     size_t startPosInSequence = seqPos->posInSequence;
     size_t endPosInSequence = seqPos->posInSequence + blockSize;
     BYTE const* ip = (BYTE const*)(src);
-    const BYTE* const iend = ip + blockSize;
-    U32 windowSize = 1 << cctx->appliedParams.cParams.windowLog;
+    BYTE const* iend = ip + blockSize;
     repcodes_t updatedRepcodes;
     U32 bytesAdjustment = 0;
-    U32 bytesread = 0;
-
-    DEBUGLOG(3, "ZSTD_copySequencesToSeqStore: idx: %zu PIS: %u blockSize: %zu windowSize: %u", idx, startPosInSequence, blockSize, windowSize);
-    DEBUGLOG(3, "start seq: idx: %u (of: %u ml: %u ll: %u)", idx, inSeqs[idx].offset, inSeqs[idx].matchLength, inSeqs[idx].litLength);
+    U32 finalMatchSplit = 0;
+    U32 dictSize;
+    U32 litLength;
+    U32 matchLength;
+    U32 rawOffset;
+    U32 offCode;
+    U32 repCode;
+    
+    if (cctx->cdict) {
+        dictSize = ZSTD_sizeof_CDict(cctx->cdict);
+    } else if (cctx->prefixDict.dictSize) {
+        dictSize = cctx->prefixDict.dictSize;
+    } else if (ZSTD_sizeof_localDict(cctx->localDict)) {
+        dictSize = ZSTD_sizeof_localDict(cctx->localDict);
+    }
+    DEBUGLOG(5, "ZSTD_copySequencesToSeqStore: idx: %zu PIS: %u blockSize: %zu windowSize: %u", idx, startPosInSequence, blockSize, windowSize);
+    DEBUGLOG(5, "Start seq: idx: %u (of: %u ml: %u ll: %u)", idx, inSeqs[idx].offset, inSeqs[idx].matchLength, inSeqs[idx].litLength);
     ZSTD_memcpy(updatedRepcodes.rep, cctx->blockState.prevCBlock->rep, sizeof(repcodes_t));
-    while (endPosInSequence && idx < inSeqsSize) {
-        ZSTD_Sequence currSeq = inSeqs[idx];
-        U32 litLength = currSeq.litLength;
-        U32 matchLength = currSeq.matchLength;
-        U32 rawOffset = currSeq.offset;
-        U32 offCode = rawOffset + ZSTD_REP_MOVE;
-        U32 repCode = cctx->calculateRepcodes ? 0 : currSeq.rep;
+    while (endPosInSequence && idx < inSeqsSize && !finalMatchSplit) {
+        const ZSTD_Sequence currSeq = inSeqs[idx];
+        litLength = currSeq.litLength;
+        matchLength = currSeq.matchLength;
+        rawOffset = currSeq.offset;
+        repCode =  0;
 
         /* Modify the sequence depending on where endPosInSequence lies */
         if (endPosInSequence >= currSeq.litLength + currSeq.matchLength) {
             if (startPosInSequence >= litLength) {
-                /* Start pos is within the match */
                 startPosInSequence -= litLength;
                 litLength = 0;
                 matchLength -= startPosInSequence;
@@ -4519,15 +4627,14 @@ static size_t ZSTD_copySequencesToSeqStore(seqStore_t* seqStore, ZSTD_sequencePo
             idx++;
         } else {
             /* This is the final sequence we're adding from inSeqs, and endPosInSequence
-               does not reach the end of the match. So, we have to split something */
-            DEBUGLOG(3, "Require a split: diff: %u, idx: %u PIS: %u", currSeq.litLength + currSeq.matchLength - endPosInSequence, idx, endPosInSequence);
-            DEBUGLOG(3, "(of: %u ml: %u ll: %u)", rawOffset, matchLength, litLength);
+               does not reach the end of the match. So, we have to split the sequence */
+            DEBUGLOG(6, "Require a split: diff: %u, idx: %u PIS: %u", currSeq.litLength + currSeq.matchLength - endPosInSequence, idx, endPosInSequence);
             if (endPosInSequence > litLength) {
-                DEBUGLOG(2, "hard case");
-                /* This sequence ends inside the match, may need to split match */
-                U32 firstHalfMatchLength = endPosInSequence - litLength;
+                litLength = startPosInSequence >= litLength ? 0 : litLength - startPosInSequence;
+                U32 firstHalfMatchLength = endPosInSequence - startPosInSequence - litLength;
                 if (matchLength > blockSize && firstHalfMatchLength >= MINMATCH) {
-                    U32 secondHalfMatchLength = matchLength - firstHalfMatchLength;
+                    /* Only ever split the match if it is larger than the block size */
+                    U32 secondHalfMatchLength = currSeq.matchLength + currSeq.litLength - endPosInSequence;
                     if (secondHalfMatchLength < MINMATCH) {
                         /* Move the endPosInSequence backward so that it creates match of MINMATCH length */
                         endPosInSequence -= MINMATCH - secondHalfMatchLength;
@@ -4535,50 +4642,40 @@ static size_t ZSTD_copySequencesToSeqStore(seqStore_t* seqStore, ZSTD_sequencePo
                         firstHalfMatchLength -= bytesAdjustment;
                     }
                     matchLength = firstHalfMatchLength;
+                    /* Flag that we split the last match - after storing the sequence, exit the loop,
+                       but keep the value of endPosInSequence */
+                    finalMatchSplit = 1;
                 } else {
-                    /* Move the position in sequence backwards so that we don't split match, and store
-                       the last literals */
-                    DEBUGLOG(2, "MOVING SEQ BACKWARDS");
-                    bytesAdjustment = endPosInSequence - litLength;
-                    endPosInSequence = litLength;
+                    /* Move the position in sequence backwards so that we don't split match, and break to store
+                     * the last literals. We use the original currSeq.litLength as a marker for where endPosInSequence
+                     * should go.
+                     */
+                    bytesAdjustment = endPosInSequence - currSeq.litLength;
+                    endPosInSequence = currSeq.litLength;
                     break;
                 }
             } else {
-                /* This sequence ends inside the literals, store the last literals */
-                litLength = startPosInSequence >= litLength ? 0 : endPosInSequence - startPosInSequence;
+                /* This sequence ends inside the literals, break to store the last literals */
                 break;
             }
         }
-
-        if (matchLength < MINMATCH) {
-            DEBUGLOG(2, "match too small");
-            DEBUGLOG(2, "%u (of: %u ml: %u ll: %u)", idx, rawOffset, matchLength, litLength);
-            RETURN_ERROR_IF(1, corruption_detected, "match");
-        }
-        if (cctx->calculateRepcodes == ZSTD_sf_calculateRepcodes) {
-            U32 ll0 = (litLength == 0);
-            /* Check if current offset matches anything in the repcode table */
-            if (!ll0 && rawOffset == updatedRepcodes.rep[0]) {
-                repCode = 1;
-            } else if (rawOffset == updatedRepcodes.rep[1]) {
-                repCode = 2 - ll0;
-            } else if (rawOffset == updatedRepcodes.rep[2]) {
-                repCode = 3 - ll0;
-            } else if (ll0 && rawOffset == updatedRepcodes.rep[0] - 1) {
-                repCode = 3;
-            }
-            if (repCode) {
-                offCode = repCode - 1;
-            }
+        /* Check if this offset can be represented with a repcode */
+        {   U32 ll0 = (litLength == 0);
+            offCode = ZSTD_finalizeOffCode(rawOffset, updatedRepcodes.rep, ll0);
             updatedRepcodes = ZSTD_updateRep(updatedRepcodes.rep, offCode, ll0);
-        } else {
-            offCode = repCode ? repCode-1 : offCode;
         }
-        DEBUGLOG(6, "Storing: idx: %zu (of: %u, ml: %u, ll: %u)", idx, offCode, matchLength, litLength);
-        /* Validate the offCode */
+
+#if defined(DEBUGLEVEL) && (DEBUGLEVEL >= 6)
+        seqPos->posInSrc += litLength;
+        FORWARD_IF_ERROR(ZSTD_validateSequence(offCode, repCode,
+                                               matchLength, seqPos->posInSrc,
+                                               cctx->appliedParams.cParams.windowLog, dictSize),
+                         "Sequence validation failed");
+        seqPos->posInSrc += matchLength;
+#endif
+        DEBUGLOG(6, "Storing sequence: (of: %u, ml: %u, ll: %u)", offCode, matchLength, litLength);
         ZSTD_storeSeq(seqStore, litLength, ip, iend, offCode, matchLength - MINMATCH);
         ip += matchLength + litLength;
-        bytesread += matchLength + litLength;
     }
     assert(endPosInSequence <= inSeqs[idx].litLength + inSeqs[idx].matchLength);
     seqPos->idx = idx;
@@ -4586,13 +4683,18 @@ static size_t ZSTD_copySequencesToSeqStore(seqStore_t* seqStore, ZSTD_sequencePo
     /* Update repcodes */
     ZSTD_memcpy(cctx->blockState.nextCBlock->rep, updatedRepcodes.rep, sizeof(repcodes_t));
 
-    /* Store any last literals for ZSTD_sf_noBlockDelimiters mode */
-    if (cctx->blockDelimiters == ZSTD_sf_noBlockDelimiters && ip != iend) {
+    iend -= bytesAdjustment;
+    if (ip != iend) {
+        /* Store any last literals */
         U32 lastLLSize = (U32)(iend - ip);
         assert(ip <= iend);
-        DEBUGLOG(2, "Storing last literals of size: %u", lastLLSize);
+        DEBUGLOG(6, "Storing last literals of size: %u", lastLLSize);
         ZSTD_storeLastLiterals(seqStore, ip, lastLLSize);
+#if defined(DEBUGLEVEL) && (DEBUGLEVEL >= 6)
+        seqPos->posInSrc += lastLLSize;
+#endif
     }
+
     return bytesAdjustment;
 }
 
@@ -4609,7 +4711,7 @@ static size_t ZSTD_compressSequences_internal(void* dst, size_t dstCapacity,
     U32 blockSize;
     U32 compressedSeqsSize;
     size_t remaining = srcSize;
-    ZSTD_sequencePosition seqPos = {0, 0};
+    ZSTD_sequencePosition seqPos = {0, 0, 0};
     seqStore_t blockSeqStore;
     
     BYTE const* ip = (BYTE const*)src;
@@ -4634,8 +4736,15 @@ static size_t ZSTD_compressSequences_internal(void* dst, size_t dstCapacity,
         blockSeqStore = cctx->seqStore;
         ZSTD_resetSeqStore(&blockSeqStore);
         DEBUGLOG(4, "Working on new block. Blocksize: %u", blockSize);
-
-        additionalByteAdjustment = ZSTD_copySequencesToSeqStore(&blockSeqStore, &seqPos, inSeqs, inSeqsSize, ip, blockSize, cctx);
+        if (cctx->appliedParams.blockDelimiters == ZSTD_sf_noBlockDelimiters) {
+            additionalByteAdjustment = ZSTD_copySequencesToSeqStore(&blockSeqStore, &seqPos,
+                                                                    inSeqs, inSeqsSize,
+                                                                    ip, blockSize, cctx);
+        } else {
+            additionalByteAdjustment = ZSTD_copySequencesToSeqStoreBlockDelim(&blockSeqStore, &seqPos,
+                                                                    inSeqs, inSeqsSize,
+                                                                    ip, blockSize, cctx);
+        }
         FORWARD_IF_ERROR(additionalByteAdjustment, "Bad sequence copy");
         blockSize -= additionalByteAdjustment;
         /* If blocks are too small, emit as a nocompress block */
@@ -4723,7 +4832,6 @@ size_t ZSTD_compressSequences(ZSTD_CCtx* const cctx, void* dst, size_t dstCapaci
     DEBUGLOG(3, "ZSTD_compressSequences()");
     assert(cctx != NULL);
     FORWARD_IF_ERROR(ZSTD_CCtx_init_compressStream2(cctx, ZSTD_e_end, srcSize), "CCtx initialization failed");
-
     /* Begin writing output, starting with frame header */
     frameHeaderSize = ZSTD_writeFrameHeader(op, dstCapacity, &cctx->appliedParams, srcSize, cctx->dictID);
     op += frameHeaderSize;