]> git.ipfire.org Git - thirdparty/zstd.git/commitdiff
switchable bit-approximation / fractional-bit accuracy modes
authorYann Collet <cyan@fb.com>
Wed, 9 May 2018 17:48:09 +0000 (10:48 -0700)
committerYann Collet <cyan@fb.com>
Wed, 9 May 2018 17:48:09 +0000 (10:48 -0700)
also : makes it possible to select nb of fractional bits.

lib/common/fse.h
lib/compress/zstd_opt.c

index 677078558e4740ddba3e8c1f978508f87c84460b..8e44c1a444ed5a1be49c299e42bfedcbfbf30538 100644 (file)
@@ -582,16 +582,20 @@ MEM_STATIC U32 FSE_getMaxNbBits(const FSE_symbolCompressionTransform* symbolTT,
 
 /* FSE_bitCost_b256() :
  * Approximate symbol cost,
- * provide fractional value, using fixed-point format (8 bit) */
-MEM_STATIC U32 FSE_bitCost_b256(const FSE_symbolCompressionTransform* symbolTT, U32 tableLog, U32 symbolValue)
+ * provide fractional value, using fixed-point format (accuracyLog fractional bits) */
+MEM_STATIC U32 FSE_bitCost(const FSE_symbolCompressionTransform* symbolTT, U32 tableLog, U32 symbolValue, U32 accuracyLog)
 {
     U32 const minNbBits = symbolTT[symbolValue].deltaNbBits >> 16;
     U32 const threshold = (minNbBits+1) << 16;
-    assert(symbolTT[symbolValue].deltaNbBits + (1<<tableLog) <= threshold);
-    U32 const deltaFromThreshold = threshold - (symbolTT[symbolValue].deltaNbBits + (1 << tableLog));
-    U32 const normalizedDeltaFromThreshold = (deltaFromThreshold << 8) >> tableLog;   /* linear interpolation (very approximate) */
-    assert(normalizedDeltaFromThreshold <= 256);
-    return (minNbBits+1)*256 - normalizedDeltaFromThreshold;
+    assert(tableLog < 16);
+    U32 const tableSize = 1 << tableLog;
+    assert(symbolTT[symbolValue].deltaNbBits + tableSize <= threshold);
+    U32 const deltaFromThreshold = threshold - (symbolTT[symbolValue].deltaNbBits + tableSize);
+    assert(accuracyLog < 31-tableLog);  /* ensure enough room for renormalization double shift */
+    U32 const normalizedDeltaFromThreshold = (deltaFromThreshold << accuracyLog) >> tableLog;   /* linear interpolation (very approximate) */
+    U32 const bitMultiplier = 1 << accuracyLog;
+    assert(normalizedDeltaFromThreshold <= bitMultiplier);
+    return (minNbBits+1)*bitMultiplier - normalizedDeltaFromThreshold;
 }
 
 
index 80edb1a7b7293093f82f8a3ffae3cf9b9f1757df..67db85eb97e4930e7b01f24fbe832be77d57e49a 100644 (file)
@@ -91,6 +91,15 @@ static void ZSTD_rescaleFreqs(optState_t* const optPtr,
     ZSTD_setLog2Prices(optPtr);
 }
 
+#if 1   /* approximation at bit level */
+#  define BITCOST_ACCURACY 0
+#  define BITCOST_MULTIPLIER (1 << BITCOST_ACCURACY)
+#  define BITCOST_SYMBOL(t,l,s)  ((void)l, FSE_getMaxNbBits(t,s)*BITCOST_MULTIPLIER)
+#else   /* fractional bit accuracy */
+#  define BITCOST_ACCURACY 8
+#  define BITCOST_MULTIPLIER (1 << BITCOST_ACCURACY)
+#  define BITCOST_SYMBOL(t,l,s)  FSE_bitCost(t,l,s,BITCOST_ACCURACY)
+#endif
 
 /* ZSTD_rawLiteralsCost() :
  * cost of literals (only) in specified segment (which length can be 0).
@@ -98,23 +107,23 @@ static void ZSTD_rescaleFreqs(optState_t* const optPtr,
 static U32 ZSTD_rawLiteralsCost(const BYTE* const literals, U32 const litLength,
                                 const optState_t* const optPtr)
 {
+    if (litLength == 0) return 0;
+    if (optPtr->priceType == zop_predef) return (litLength*6);  /* 6 bit per literal - no statistic used */
     if (optPtr->priceType == zop_static) {
         U32 u, cost;
         assert(optPtr->symbolCosts != NULL);
         assert(optPtr->symbolCosts->hufCTable_repeatMode == HUF_repeat_valid);
         for (u=0, cost=0; u < litLength; u++)
             cost += HUF_getNbBits(optPtr->symbolCosts->hufCTable, literals[u]);
-        return cost << 8;
+        return cost * BITCOST_MULTIPLIER;
     }
-    if (optPtr->priceType == zop_predef) return (litLength*6);  /* 6 bit per literal - no statistic used */
-    if (litLength == 0) return 0;
 
-    /* literals */
+    /* dynamic statistics */
     {   U32 u;
         U32 cost = litLength * optPtr->log2litSum;
         for (u=0; u < litLength; u++)
             cost -= ZSTD_highbit32(optPtr->litFreq[literals[u]]+1);
-        return cost << 8;
+        return cost * BITCOST_MULTIPLIER;
     }
 }
 
@@ -126,15 +135,15 @@ static U32 ZSTD_litLengthPrice(U32 const litLength, const optState_t* const optP
         U32 const llCode = ZSTD_LLcode(litLength);
         FSE_CState_t cstate;
         FSE_initCState(&cstate, optPtr->symbolCosts->litlengthCTable);
-        U32 const price = LL_bits[llCode]*256 + FSE_bitCost_b256(cstate.symbolTT, cstate.stateLog, llCode);
-        DEBUGLOG(8, "ZSTD_litLengthPrice: ll=%u, bitCost=%.2f", litLength, (double)price / 256);
+        U32 const price = LL_bits[llCode]*BITCOST_MULTIPLIER + BITCOST_SYMBOL(cstate.symbolTT, cstate.stateLog, llCode);
+        DEBUGLOG(8, "ZSTD_litLengthPrice: ll=%u, bitCost=%.2f", litLength, (double)price / BITCOST_MULTIPLIER);
         return price;
     }
     if (optPtr->priceType == zop_predef) return ZSTD_highbit32((U32)litLength+1);
 
     /* dynamic statistics */
     {   U32 const llCode = ZSTD_LLcode(litLength);
-        return (LL_bits[llCode] + optPtr->log2litLengthSum - ZSTD_highbit32(optPtr->litLengthFreq[llCode]+1)) << 8;
+        return (LL_bits[llCode] + optPtr->log2litLengthSum - ZSTD_highbit32(optPtr->litLengthFreq[llCode]+1)) * BITCOST_MULTIPLIER;
     }
 }
 
@@ -158,18 +167,18 @@ static int ZSTD_litLengthContribution(U32 const litLength, const optState_t* con
         U32 const llCode = ZSTD_LLcode(litLength);
         FSE_CState_t cstate;
         FSE_initCState(&cstate, optPtr->symbolCosts->litlengthCTable);
-        return (int)(LL_bits[llCode] * 256)
-             + FSE_bitCost_b256(cstate.symbolTT, cstate.stateLog, llCode)
-             - FSE_bitCost_b256(cstate.symbolTT, cstate.stateLog, 0);
+        return (int)(LL_bits[llCode] * BITCOST_MULTIPLIER)
+             + BITCOST_SYMBOL(cstate.symbolTT, cstate.stateLog, llCode)
+             - BITCOST_SYMBOL(cstate.symbolTT, cstate.stateLog, 0);
     }
     if (optPtr->priceType >= zop_predef) return ZSTD_highbit32(litLength+1);
 
-    /* literal Length */
+    /* dynamic statistics */
     {   U32 const llCode = ZSTD_LLcode(litLength);
         int const contribution = (LL_bits[llCode]
                         + ZSTD_highbit32(optPtr->litLengthFreq[0]+1)
                         - ZSTD_highbit32(optPtr->litLengthFreq[llCode]+1))
-                        * 256;
+                        * BITCOST_MULTIPLIER;
 #if 1
         return contribution;
 #else
@@ -209,13 +218,14 @@ ZSTD_getMatchPrice(U32 const offset, U32 const matchLength,
         FSE_CState_t mlstate, offstate;
         FSE_initCState(&mlstate, optPtr->symbolCosts->matchlengthCTable);
         FSE_initCState(&offstate, optPtr->symbolCosts->offcodeCTable);
-        return FSE_bitCost_b256(offstate.symbolTT, offstate.stateLog, offCode) + offCode*256
-             + FSE_bitCost_b256(mlstate.symbolTT, mlstate.stateLog, mlCode) + ML_bits[mlCode]*256;
+        return BITCOST_SYMBOL(offstate.symbolTT, offstate.stateLog, offCode) + offCode*BITCOST_MULTIPLIER
+             + BITCOST_SYMBOL(mlstate.symbolTT, mlstate.stateLog, mlCode) + ML_bits[mlCode]*BITCOST_MULTIPLIER;
     }
 
     if (optPtr->priceType == zop_predef)  /* fixed scheme, do not use statistics */
         return ZSTD_highbit32(mlBase+1) + 16 + offCode;
 
+    /* dynamic statistics */
     price = offCode + optPtr->log2offCodeSum - ZSTD_highbit32(optPtr->offCodeFreq[offCode]+1);
     if ((optLevel<2) /*static*/ && offCode >= 20) price += (offCode-19)*2; /* handicap for long distance offsets, favor decompression speed */
 
@@ -225,7 +235,7 @@ ZSTD_getMatchPrice(U32 const offset, U32 const matchLength,
     }
 
     DEBUGLOG(8, "ZSTD_getMatchPrice(ml:%u) = %u", matchLength, price);
-    return price << 8;
+    return price * BITCOST_MULTIPLIER;
 }
 
 static void ZSTD_updateStats(optState_t* const optPtr,