]> git.ipfire.org Git - thirdparty/zstd.git/commitdiff
Optimize ZSTD_decodeSequence by another x%
authorDanila Kutenin <kutdanila@yandex.ru>
Sat, 29 May 2021 17:21:10 +0000 (18:21 +0100)
committerDanila Kutenin <kutdanila@yandex.ru>
Sat, 29 May 2021 17:21:10 +0000 (18:21 +0100)
lib/common/bitstream.h
lib/decompress/zstd_decompress_block.c

index 2e5a933ad3dd15349b2ae114e5fc5757e0923543..1a494ee8567208755f9c1724cbe2ca1d1346946f 100644 (file)
@@ -332,7 +332,17 @@ MEM_STATIC FORCE_INLINE_ATTR size_t BIT_getMiddleBits(size_t bitContainer, U32 c
     U32 const regMask = sizeof(bitContainer)*8 - 1;
     /* if start > regMask, bitstream is corrupted, and result is undefined */
     assert(nbBits < BIT_MASK_SIZE);
+    /* x86 transform & ((1 << nbBits) - 1) to bzhi instruction, it is better
+     * than accessing memory. When bmi2 instruction is not present, we consider
+     * such cpus old (pre-Haswell, 2013) and their performance is not of that
+     * importance.
+     */
+#if defined(__x86_64__) || defined(_M_X86)
+    U64 const one = 1;
+    return (bitContainer >> (start & regMask)) & ((one << nbBits) - 1);
+#else
     return (bitContainer >> (start & regMask)) & BIT_mask[nbBits];
+#endif
 }
 
 MEM_STATIC FORCE_INLINE_ATTR size_t BIT_getLowerBits(size_t bitContainer, U32 const nbBits)
index 349dcdc333627adea2549a2c0e3697fb6bcb1fc3..6c40be4163e76b2d0087b3d6441161071fc80eeb 100644 (file)
@@ -905,20 +905,10 @@ ZSTD_initFseState(ZSTD_fseState* DStatePtr, BIT_DStream_t* bitD, const ZSTD_seqS
 }
 
 FORCE_INLINE_TEMPLATE void
-ZSTD_updateFseState(ZSTD_fseState* DStatePtr, BIT_DStream_t* bitD)
+ZSTD_updateFseStateWithDInfo(ZSTD_fseState* DStatePtr, BIT_DStream_t* bitD, U16 nextState, U32 nbBits)
 {
-    ZSTD_seqSymbol const DInfo = DStatePtr->table[DStatePtr->state];
-    U32 const nbBits = DInfo.nbBits;
     size_t const lowBits = BIT_readBits(bitD, nbBits);
-    DStatePtr->state = DInfo.nextState + lowBits;
-}
-
-FORCE_INLINE_TEMPLATE void
-ZSTD_updateFseStateWithDInfo(ZSTD_fseState* DStatePtr, BIT_DStream_t* bitD, ZSTD_seqSymbol const DInfo)
-{
-    U32 const nbBits = DInfo.nbBits;
-    size_t const lowBits = BIT_readBits(bitD, nbBits);
-    DStatePtr->state = DInfo.nextState + lowBits;
+    DStatePtr->state = nextState + lowBits;
 }
 
 /* We need to add at most (ZSTD_WINDOWLOG_MAX_32 - 1) bits to read the maximum
@@ -937,20 +927,36 @@ FORCE_INLINE_TEMPLATE seq_t
 ZSTD_decodeSequence(seqState_t* seqState, const ZSTD_longOffset_e longOffsets)
 {
     seq_t seq;
-    ZSTD_seqSymbol const llDInfo = seqState->stateLL.table[seqState->stateLL.state];
-    ZSTD_seqSymbol const mlDInfo = seqState->stateML.table[seqState->stateML.state];
-    ZSTD_seqSymbol const ofDInfo = seqState->stateOffb.table[seqState->stateOffb.state];
-    U32 const llBase = llDInfo.baseValue;
-    U32 const mlBase = mlDInfo.baseValue;
-    U32 const ofBase = ofDInfo.baseValue;
-    BYTE const llBits = llDInfo.nbAdditionalBits;
-    BYTE const mlBits = mlDInfo.nbAdditionalBits;
-    BYTE const ofBits = ofDInfo.nbAdditionalBits;
+    const ZSTD_seqSymbol* const llDInfo = seqState->stateLL.table + seqState->stateLL.state;
+    const ZSTD_seqSymbol* const mlDInfo = seqState->stateML.table + seqState->stateML.state;
+    const ZSTD_seqSymbol* const ofDInfo = seqState->stateOffb.table + seqState->stateOffb.state;
+    seq.matchLength = mlDInfo->baseValue;
+    seq.litLength = llDInfo->baseValue;
+    U32 const ofBase = ofDInfo->baseValue;
+    BYTE const llBits = llDInfo->nbAdditionalBits;
+    BYTE const mlBits = mlDInfo->nbAdditionalBits;
+    BYTE const ofBits = ofDInfo->nbAdditionalBits;
     BYTE const totalBits = llBits+mlBits+ofBits;
 
+    U16 const llNext = llDInfo->nextState;
+    U16 const mlNext = mlDInfo->nextState;
+    U16 const ofNext = ofDInfo->nextState;
+    U32 const llnbBits = llDInfo->nbBits;
+    U32 const mlnbBits = mlDInfo->nbBits;
+    U32 const ofnbBits = ofDInfo->nbBits;
+    /*
+     * As gcc has better branch and block analyzers, sometimes it is only
+     * valuable to mark likelyness for clang, it gives around 3-4% of
+     * performance.
+     */
+
     /* sequence */
     {   size_t offset;
+#if defined(__clang__)
+        if (LIKELY(ofBits > 1)) {
+#else
         if (ofBits > 1) {
+#endif
             ZSTD_STATIC_ASSERT(ZSTD_lo_isLongOffset == 1);
             ZSTD_STATIC_ASSERT(LONG_OFFSETS_MAX_EXTRA_BITS_32 == 5);
             assert(ofBits <= MaxOff);
@@ -968,12 +974,10 @@ ZSTD_decodeSequence(seqState_t* seqState, const ZSTD_longOffset_e longOffsets)
             seqState->prevOffset[1] = seqState->prevOffset[0];
             seqState->prevOffset[0] = offset;
         } else {
-            U32 const ll0 = (llBase == 0);
+            U32 const ll0 = (llDInfo->baseValue == 0);
             if (LIKELY((ofBits == 0))) {
-                if (LIKELY(!ll0))
-                    offset = seqState->prevOffset[0];
-                else {
-                    offset = seqState->prevOffset[1];
+                offset = seqState->prevOffset[ll0];
+                if (UNLIKELY(ll0)) {
                     seqState->prevOffset[1] = seqState->prevOffset[0];
                     seqState->prevOffset[0] = offset;
                 }
@@ -988,8 +992,11 @@ ZSTD_decodeSequence(seqState_t* seqState, const ZSTD_longOffset_e longOffsets)
         seq.offset = offset;
     }
 
-    seq.matchLength = mlBase;
+#if defined(__clang__)
+    if (UNLIKELY(mlBits > 0))
+#else
     if (mlBits > 0)
+#endif
         seq.matchLength += BIT_readBitsFast(&seqState->DStream, mlBits/*>0*/);
 
     if (MEM_32bits() && (mlBits+llBits >= STREAM_ACCUMULATOR_MIN_32-LONG_OFFSETS_MAX_EXTRA_BITS_32))
@@ -999,8 +1006,11 @@ ZSTD_decodeSequence(seqState_t* seqState, const ZSTD_longOffset_e longOffsets)
     /* Ensure there are enough bits to read the rest of data in 64-bit mode. */
     ZSTD_STATIC_ASSERT(16+LLFSELog+MLFSELog+OffFSELog < STREAM_ACCUMULATOR_MIN_64);
 
-    seq.litLength = llBase;
+#if defined(__clang__)
+    if (UNLIKELY(llBits > 0))
+#else
     if (llBits > 0)
+#endif
         seq.litLength += BIT_readBitsFast(&seqState->DStream, llBits/*>0*/);
 
     if (MEM_32bits())
@@ -1009,31 +1019,10 @@ ZSTD_decodeSequence(seqState_t* seqState, const ZSTD_longOffset_e longOffsets)
     DEBUGLOG(6, "seq: litL=%u, matchL=%u, offset=%u",
                 (U32)seq.litLength, (U32)seq.matchLength, (U32)seq.offset);
 
-    /* ANS state update
-     * gcc-9.0.0 does 2.5% worse with ZSTD_updateFseStateWithDInfo().
-     * clang-9.2.0 does 7% worse with ZSTD_updateFseState().
-     * Naturally it seems like ZSTD_updateFseStateWithDInfo() should be the
-     * better option, so it is the default for other compilers. But, if you
-     * measure that it is worse, please put up a pull request.
-     */
-    {
-#if defined(__GNUC__) && !defined(__clang__)
-        const int kUseUpdateFseState = 1;
-#else
-        const int kUseUpdateFseState = 0;
-#endif
-        if (kUseUpdateFseState) {
-            ZSTD_updateFseState(&seqState->stateLL, &seqState->DStream);    /* <=  9 bits */
-            ZSTD_updateFseState(&seqState->stateML, &seqState->DStream);    /* <=  9 bits */
-            if (MEM_32bits()) BIT_reloadDStream(&seqState->DStream);    /* <= 18 bits */
-            ZSTD_updateFseState(&seqState->stateOffb, &seqState->DStream);  /* <=  8 bits */
-        } else {
-            ZSTD_updateFseStateWithDInfo(&seqState->stateLL, &seqState->DStream, llDInfo);    /* <=  9 bits */
-            ZSTD_updateFseStateWithDInfo(&seqState->stateML, &seqState->DStream, mlDInfo);    /* <=  9 bits */
-            if (MEM_32bits()) BIT_reloadDStream(&seqState->DStream);    /* <= 18 bits */
-            ZSTD_updateFseStateWithDInfo(&seqState->stateOffb, &seqState->DStream, ofDInfo);  /* <=  8 bits */
-        }
-    }
+    ZSTD_updateFseStateWithDInfo(&seqState->stateLL, &seqState->DStream, llNext, llnbBits);    /* <=  9 bits */
+    ZSTD_updateFseStateWithDInfo(&seqState->stateML, &seqState->DStream, mlNext, mlnbBits);    /* <=  9 bits */
+    if (MEM_32bits()) BIT_reloadDStream(&seqState->DStream);    /* <= 18 bits */
+    ZSTD_updateFseStateWithDInfo(&seqState->stateOffb, &seqState->DStream, ofNext, ofnbBits);  /* <=  8 bits */
 
     return seq;
 }