]> git.ipfire.org Git - thirdparty/zstd.git/commitdiff
changed (partially) the decodeSequences flow logic
authorYann Collet <cyan@fb.com>
Fri, 16 Jun 2023 18:56:22 +0000 (11:56 -0700)
committerYann Collet <cyan@fb.com>
Fri, 16 Jun 2023 18:57:12 +0000 (11:57 -0700)
this allows detecting overflow events without a checksum.

lib/decompress/zstd_decompress_block.c

index 93947ba584f362983b43d1f5f15d54b09f8d4c6b..2d02200434307dacce9235c1c1ab93e24eccdd71 100644 (file)
@@ -1214,14 +1214,20 @@ ZSTD_updateFseStateWithDInfo(ZSTD_fseState* DStatePtr, BIT_DStream_t* bitD, U16
 
 typedef enum { ZSTD_lo_isRegularOffset, ZSTD_lo_isLongOffset=1 } ZSTD_longOffset_e;
 
+/**
+ * ZSTD_decodeSequence_old():
+ * @p longOffsets : tells the decoder to reload more bit while decoding large offsets
+ *                  only used in 32-bit mode
+ * @return : Sequence (litL + matchL + offset)
+ */
 FORCE_INLINE_TEMPLATE seq_t
-ZSTD_decodeSequence(seqState_t* seqState, const ZSTD_longOffset_e longOffsets)
+ZSTD_decodeSequence_old(seqState_t* seqState, const ZSTD_longOffset_e longOffsets)
 {
     seq_t seq;
     /*
-     * ZSTD_seqSymbol is a structure with a total of 64 bits wide. So it can be
-     * loaded in one operation and extracted its fields by simply shifting or
-     * bit-extracting on aarch64.
+     * ZSTD_seqSymbol is a 64 bits wide structure.
+     * It can be loaded in one operation
+     * and its fields extracted by simply shifting or bit-extracting on aarch64.
      * GCC doesn't recognize this and generates more unnecessary ldr/ldrb/ldrh
      * operations that cause performance drop. This can be avoided by using this
      * ZSTD_memcpy hack.
@@ -1330,6 +1336,132 @@ ZSTD_decodeSequence(seqState_t* seqState, const ZSTD_longOffset_e longOffsets)
     return seq;
 }
 
+/**
+ * ZSTD_decodeSequence():
+ * @p longOffsets : tells the decoder to reload more bit while decoding large offsets
+ *                  only used in 32-bit mode
+ * @return : Sequence (litL + matchL + offset)
+ */
+FORCE_INLINE_TEMPLATE seq_t
+ZSTD_decodeSequence(seqState_t* seqState, const ZSTD_longOffset_e longOffsets, const int isLastSeq)
+{
+    seq_t seq;
+    /*
+     * ZSTD_seqSymbol is a 64 bits wide structure.
+     * It can be loaded in one operation
+     * and its fields extracted by simply shifting or bit-extracting on aarch64.
+     * GCC doesn't recognize this and generates more unnecessary ldr/ldrb/ldrh
+     * operations that cause performance drop. This can be avoided by using this
+     * ZSTD_memcpy hack.
+     */
+#if defined(__aarch64__) && (defined(__GNUC__) && !defined(__clang__))
+    ZSTD_seqSymbol llDInfoS, mlDInfoS, ofDInfoS;
+    ZSTD_seqSymbol* const llDInfo = &llDInfoS;
+    ZSTD_seqSymbol* const mlDInfo = &mlDInfoS;
+    ZSTD_seqSymbol* const ofDInfo = &ofDInfoS;
+    ZSTD_memcpy(llDInfo, seqState->stateLL.table + seqState->stateLL.state, sizeof(ZSTD_seqSymbol));
+    ZSTD_memcpy(mlDInfo, seqState->stateML.table + seqState->stateML.state, sizeof(ZSTD_seqSymbol));
+    ZSTD_memcpy(ofDInfo, seqState->stateOffb.table + seqState->stateOffb.state, sizeof(ZSTD_seqSymbol));
+#else
+    const ZSTD_seqSymbol* const llDInfo = seqState->stateLL.table + seqState->stateLL.state;
+    const ZSTD_seqSymbol* const mlDInfo = seqState->stateML.table + seqState->stateML.state;
+    const ZSTD_seqSymbol* const ofDInfo = seqState->stateOffb.table + seqState->stateOffb.state;
+#endif
+    seq.matchLength = mlDInfo->baseValue;
+    seq.litLength = llDInfo->baseValue;
+    {   U32 const ofBase = ofDInfo->baseValue;
+        BYTE const llBits = llDInfo->nbAdditionalBits;
+        BYTE const mlBits = mlDInfo->nbAdditionalBits;
+        BYTE const ofBits = ofDInfo->nbAdditionalBits;
+        BYTE const totalBits = llBits+mlBits+ofBits;
+
+        U16 const llNext = llDInfo->nextState;
+        U16 const mlNext = mlDInfo->nextState;
+        U16 const ofNext = ofDInfo->nextState;
+        U32 const llnbBits = llDInfo->nbBits;
+        U32 const mlnbBits = mlDInfo->nbBits;
+        U32 const ofnbBits = ofDInfo->nbBits;
+
+        assert(llBits <= MaxLLBits);
+        assert(mlBits <= MaxMLBits);
+        assert(ofBits <= MaxOff);
+        /*
+         * As gcc has better branch and block analyzers, sometimes it is only
+         * valuable to mark likeliness for clang, it gives around 3-4% of
+         * performance.
+         */
+
+        /* sequence */
+        {   size_t offset;
+            if (ofBits > 1) {
+                ZSTD_STATIC_ASSERT(ZSTD_lo_isLongOffset == 1);
+                ZSTD_STATIC_ASSERT(LONG_OFFSETS_MAX_EXTRA_BITS_32 == 5);
+                ZSTD_STATIC_ASSERT(STREAM_ACCUMULATOR_MIN_32 > LONG_OFFSETS_MAX_EXTRA_BITS_32);
+                ZSTD_STATIC_ASSERT(STREAM_ACCUMULATOR_MIN_32 - LONG_OFFSETS_MAX_EXTRA_BITS_32 >= MaxMLBits);
+                if (MEM_32bits() && longOffsets && (ofBits >= STREAM_ACCUMULATOR_MIN_32)) {
+                    /* Always read extra bits, this keeps the logic simple,
+                     * avoids branches, and avoids accidentally reading 0 bits.
+                     */
+                    U32 const extraBits = LONG_OFFSETS_MAX_EXTRA_BITS_32;
+                    offset = ofBase + (BIT_readBitsFast(&seqState->DStream, ofBits - extraBits) << extraBits);
+                    BIT_reloadDStream(&seqState->DStream);
+                    offset += BIT_readBitsFast(&seqState->DStream, extraBits);
+                } else {
+                    offset = ofBase + BIT_readBitsFast(&seqState->DStream, ofBits/*>0*/);   /* <=  (ZSTD_WINDOWLOG_MAX-1) bits */
+                    if (MEM_32bits()) BIT_reloadDStream(&seqState->DStream);
+                }
+                seqState->prevOffset[2] = seqState->prevOffset[1];
+                seqState->prevOffset[1] = seqState->prevOffset[0];
+                seqState->prevOffset[0] = offset;
+            } else {
+                U32 const ll0 = (llDInfo->baseValue == 0);
+                if (LIKELY((ofBits == 0))) {
+                    offset = seqState->prevOffset[ll0];
+                    seqState->prevOffset[1] = seqState->prevOffset[!ll0];
+                    seqState->prevOffset[0] = offset;
+                } else {
+                    offset = ofBase + ll0 + BIT_readBitsFast(&seqState->DStream, 1);
+                    {   size_t temp = (offset==3) ? seqState->prevOffset[0] - 1 : seqState->prevOffset[offset];
+                        temp += !temp;   /* 0 is not valid; input is corrupted; force offset to 1 */
+                        if (offset != 1) seqState->prevOffset[2] = seqState->prevOffset[1];
+                        seqState->prevOffset[1] = seqState->prevOffset[0];
+                        seqState->prevOffset[0] = offset = temp;
+            }   }   }
+            seq.offset = offset;
+        }
+
+        if (mlBits > 0)
+            seq.matchLength += BIT_readBitsFast(&seqState->DStream, mlBits/*>0*/);
+
+        if (MEM_32bits() && (mlBits+llBits >= STREAM_ACCUMULATOR_MIN_32-LONG_OFFSETS_MAX_EXTRA_BITS_32))
+            BIT_reloadDStream(&seqState->DStream);
+        if (MEM_64bits() && UNLIKELY(totalBits >= STREAM_ACCUMULATOR_MIN_64-(LLFSELog+MLFSELog+OffFSELog)))
+            BIT_reloadDStream(&seqState->DStream);
+        /* Ensure there are enough bits to read the rest of data in 64-bit mode. */
+        ZSTD_STATIC_ASSERT(16+LLFSELog+MLFSELog+OffFSELog < STREAM_ACCUMULATOR_MIN_64);
+
+        if (llBits > 0)
+            seq.litLength += BIT_readBitsFast(&seqState->DStream, llBits/*>0*/);
+
+        if (MEM_32bits())
+            BIT_reloadDStream(&seqState->DStream);
+
+        DEBUGLOG(6, "seq: litL=%u, matchL=%u, offset=%u",
+                    (U32)seq.litLength, (U32)seq.matchLength, (U32)seq.offset);
+
+        if (!isLastSeq) {
+            /* don't update FSE state for last Sequence */
+            ZSTD_updateFseStateWithDInfo(&seqState->stateLL, &seqState->DStream, llNext, llnbBits);    /* <=  9 bits */
+            ZSTD_updateFseStateWithDInfo(&seqState->stateML, &seqState->DStream, mlNext, mlnbBits);    /* <=  9 bits */
+            if (MEM_32bits()) BIT_reloadDStream(&seqState->DStream);    /* <= 18 bits */
+            ZSTD_updateFseStateWithDInfo(&seqState->stateOffb, &seqState->DStream, ofNext, ofnbBits);  /* <=  8 bits */
+            BIT_reloadDStream(&seqState->DStream);
+        }
+    }
+
+    return seq;
+}
+
 #ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION
 MEM_STATIC int ZSTD_dictionaryIsActive(ZSTD_DCtx const* dctx, BYTE const* prefixStart, BYTE const* oLitEnd)
 {
@@ -1420,7 +1552,7 @@ ZSTD_decompressSequences_bodySplitLitBuffer( ZSTD_DCtx* dctx,
 
         /* decompress without overrunning litPtr begins */
         {
-            seq_t sequence = ZSTD_decodeSequence(&seqState, isLongOffset);
+            seq_t sequence = ZSTD_decodeSequence_old(&seqState, isLongOffset);
             /* Align the decompression loop to 32 + 16 bytes.
                 *
                 * zstd compiled with gcc-9 on an Intel i9-9900k shows 10% decompression
@@ -1495,7 +1627,7 @@ ZSTD_decompressSequences_bodySplitLitBuffer( ZSTD_DCtx* dctx,
                 if (UNLIKELY(!--nbSeq))
                     break;
                 BIT_reloadDStream(&(seqState.DStream));
-                sequence = ZSTD_decodeSequence(&seqState, isLongOffset);
+                sequence = ZSTD_decodeSequence_old(&seqState, isLongOffset);
             }
 
             /* If there are more sequences, they will need to read literals from litExtraBuffer; copy over the remainder from dst and update litPtr and litEnd */
@@ -1548,7 +1680,7 @@ ZSTD_decompressSequences_bodySplitLitBuffer( ZSTD_DCtx* dctx,
 #endif
 
             for (; ; ) {
-                seq_t const sequence = ZSTD_decodeSequence(&seqState, isLongOffset);
+                seq_t const sequence = ZSTD_decodeSequence_old(&seqState, isLongOffset);
                 size_t const oneSeqSize = ZSTD_execSequence(op, oend, sequence, &litPtr, litBufferEnd, prefixStart, vBase, dictEnd);
 #if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) && defined(FUZZING_ASSERT_VALID_SEQUENCE)
                 assert(!ZSTD_isError(oneSeqSize));
@@ -1647,8 +1779,8 @@ ZSTD_decompressSequences_body(ZSTD_DCtx* dctx,
 #  endif
 #endif
 
-        for ( ; ) {
-            seq_t const sequence = ZSTD_decodeSequence(&seqState, isLongOffset);
+        for ( ; nbSeq ; nbSeq--) {
+            seq_t const sequence = ZSTD_decodeSequence(&seqState, isLongOffset, nbSeq==1);
             size_t const oneSeqSize = ZSTD_execSequence(op, oend, sequence, &litPtr, litEnd, prefixStart, vBase, dictEnd);
 #if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) && defined(FUZZING_ASSERT_VALID_SEQUENCE)
             assert(!ZSTD_isError(oneSeqSize));
@@ -1658,15 +1790,13 @@ ZSTD_decompressSequences_body(ZSTD_DCtx* dctx,
                 return oneSeqSize;
             DEBUGLOG(6, "regenerated sequence size : %u", (U32)oneSeqSize);
             op += oneSeqSize;
-            if (UNLIKELY(!--nbSeq))
-                break;
-            BIT_reloadDStream(&(seqState.DStream));
         }
 
         /* check if reached exact end */
         DEBUGLOG(5, "ZSTD_decompressSequences_body: after decode loop, remaining nbSeq : %i", nbSeq);
         RETURN_ERROR_IF(nbSeq, corruption_detected, "");
-        RETURN_ERROR_IF(BIT_reloadDStream(&seqState.DStream) < BIT_DStream_completed, corruption_detected, "");
+        DEBUGLOG(5, "bitStream : start=%p, ptr=%p, bitsConsumed=%u", seqState.DStream.start, seqState.DStream.ptr, seqState.DStream.bitsConsumed);
+        RETURN_ERROR_IF(!BIT_endOfDStream(&seqState.DStream), corruption_detected, "");
         /* save reps for next block */
         { U32 i; for (i=0; i<ZSTD_REP_NUM; i++) dctx->entropy.rep[i] = (U32)(seqState.prevOffset[i]); }
     }
@@ -1763,7 +1893,7 @@ ZSTD_decompressSequencesLong_body(
 
         /* prepare in advance */
         for (seqNb=0; (BIT_reloadDStream(&seqState.DStream) <= BIT_DStream_completed) && (seqNb<seqAdvance); seqNb++) {
-            seq_t const sequence = ZSTD_decodeSequence(&seqState, isLongOffset);
+            seq_t const sequence = ZSTD_decodeSequence_old(&seqState, isLongOffset);
             prefetchPos = ZSTD_prefetchMatch(prefetchPos, sequence, prefixStart, dictEnd);
             sequences[seqNb] = sequence;
         }
@@ -1771,7 +1901,7 @@ ZSTD_decompressSequencesLong_body(
 
         /* decompress without stomping litBuffer */
         for (; (BIT_reloadDStream(&(seqState.DStream)) <= BIT_DStream_completed) && (seqNb < nbSeq); seqNb++) {
-            seq_t sequence = ZSTD_decodeSequence(&seqState, isLongOffset);
+            seq_t sequence = ZSTD_decodeSequence_old(&seqState, isLongOffset);
             size_t oneSeqSize;
 
             if (dctx->litBufferLocation == ZSTD_split && litPtr + sequences[(seqNb - ADVANCED_SEQS) & STORED_SEQS_MASK].litLength > dctx->litBufferEnd)