]> git.ipfire.org Git - thirdparty/zstd.git/commitdiff
introduced bit-fractional cost evaluation
authorYann Collet <cyan@fb.com>
Wed, 16 May 2018 21:53:35 +0000 (14:53 -0700)
committerYann Collet <cyan@fb.com>
Wed, 16 May 2018 21:53:35 +0000 (14:53 -0700)
this improves compression ratio by a *tiny* amount.
It also reduces speed by a small amount.

Consequently, bit-fractional evaluation is only turned on for btultra.

lib/compress/zstd_compress_internal.h
lib/compress/zstd_opt.c

index d9edfc31c6f18e54f00d5a8b6dc8ef7e8fe1d321..8e13aadd8e3dfbe18ad79901ca5c0e2f2f20b862 100644 (file)
@@ -91,13 +91,11 @@ typedef struct {
     U32  litLengthSum;           /* nb of litLength codes */
     U32  matchLengthSum;         /* nb of matchLength codes */
     U32  offCodeSum;             /* nb of offset codes */
-    /* begin updated by ZSTD_setLog2Prices */
-    U32  log2litSum;             /* pow2 to compare log2(litfreq) to */
-    U32  log2litLengthSum;       /* pow2 to compare log2(llfreq) to */
-    U32  log2matchLengthSum;     /* pow2 to compare log2(mlfreq) to */
-    U32  log2offCodeSum;         /* pow2 to compare log2(offreq) to */
-    /* end : updated by ZSTD_setLog2Prices */
-    ZSTD_OptPrice_e priceType;   /* prices can be determined dynamically, or follow dictionary statistics, or a pre-defined cost structure */
+    U32  litSumBasePrice;        /* to compare to log2(litfreq) */
+    U32  litLengthSumBasePrice;  /* to compare to log2(llfreq)  */
+    U32  matchLengthSumBasePrice;/* to compare to log2(mlfreq)  */
+    U32  offCodeSumBasePrice;    /* to compare to log2(offreq)  */
+    ZSTD_OptPrice_e priceType;   /* prices can be determined dynamically, or follow a pre-defined cost structure */
     const ZSTD_entropyCTables_t* symbolCosts;  /* pre-calculated dictionary statistics */
 } optState_t;
 
index fa46816119341cded2a9b58d75a235f095a96247..274bce1483491c120b021abd9fa2cb3fb66aa8da 100644 (file)
 /*-*************************************
 *  Price functions for optimal parser
 ***************************************/
-static void ZSTD_setLog2Prices(optState_t* optPtr)
+
+#if 0    /* approximation at bit level */
+#  define BITCOST_ACCURACY 0
+#  define BITCOST_MULTIPLIER (1 << BITCOST_ACCURACY)
+#  define WEIGHT(stat)  ((void)opt, ZSTD_bitWeight(stat))
+#elif 0  /* fractional bit accuracy */
+#  define BITCOST_ACCURACY 8
+#  define BITCOST_MULTIPLIER (1 << BITCOST_ACCURACY)
+#  define WEIGHT(stat,opt) ((void)opt, ZSTD_fracWeight(stat))
+#else   /* opt==approx, ultra==accurate */
+#  define BITCOST_ACCURACY 8
+#  define BITCOST_MULTIPLIER (1 << BITCOST_ACCURACY)
+#  define WEIGHT(stat,opt) (opt ? ZSTD_fracWeight(stat) : ZSTD_bitWeight(stat) )
+#endif
+
+MEM_STATIC U32 ZSTD_bitWeight(U32 stat)
+{
+    return (ZSTD_highbit32((stat)+1) * BITCOST_MULTIPLIER);
+}
+
+MEM_STATIC U32 ZSTD_fracWeight(U32 stat)
+{
+    U32 const hb = stat ? ZSTD_highbit32(stat) : 0;
+    U32 const BWeight = hb * BITCOST_MULTIPLIER;
+    U32 const FWeight = (stat << BITCOST_ACCURACY) >> hb;
+    U32 const weight = BWeight + FWeight;
+    assert(hb + BITCOST_ACCURACY < 31);
+    DEBUGLOG(2, "stat=%u, hb=%u, weight=%u", stat, hb, weight)
+    return weight;
+}
+
+/* debugging function, @return price in bytes */
+MEM_STATIC double ZSTD_fCost(U32 price)
 {
-    optPtr->log2litSum = ZSTD_highbit32(optPtr->litSum+1);
-    optPtr->log2litLengthSum = ZSTD_highbit32(optPtr->litLengthSum+1);
-    optPtr->log2matchLengthSum = ZSTD_highbit32(optPtr->matchLengthSum+1);
-    optPtr->log2offCodeSum = ZSTD_highbit32(optPtr->offCodeSum+1);
+    return (double)price / (BITCOST_MULTIPLIER*8);
+}
+
+static void ZSTD_setBasePrices(optState_t* optPtr, int optLevel)
+{
+    optPtr->litSumBasePrice = WEIGHT(optPtr->litSum, optLevel);
+    optPtr->litLengthSumBasePrice = WEIGHT(optPtr->litLengthSum, optLevel);
+    optPtr->matchLengthSumBasePrice = WEIGHT(optPtr->matchLengthSum, optLevel);
+    optPtr->offCodeSumBasePrice = WEIGHT(optPtr->offCodeSum, optLevel);
 }
 
 
 static void ZSTD_rescaleFreqs(optState_t* const optPtr,
-                              const BYTE* const src, size_t const srcSize)
+                              const BYTE* const src, size_t const srcSize,
+                              int optLevel)
 {
     optPtr->priceType = zop_dynamic;
 
@@ -103,20 +141,20 @@ static void ZSTD_rescaleFreqs(optState_t* const optPtr,
             {   unsigned ll;
                 for (ll=0; ll<=MaxLL; ll++)
                     optPtr->litLengthFreq[ll] = 1;
-                optPtr->litLengthSum = MaxLL+1;
             }
+            optPtr->litLengthSum = MaxLL+1;
 
             {   unsigned ml;
                 for (ml=0; ml<=MaxML; ml++)
                     optPtr->matchLengthFreq[ml] = 1;
-                optPtr->matchLengthSum = MaxML+1;
             }
+            optPtr->matchLengthSum = MaxML+1;
 
             {   unsigned of;
                 for (of=0; of<=MaxOff; of++)
                     optPtr->offCodeFreq[of] = 1;
-                optPtr->offCodeSum = MaxOff+1;
             }
+            optPtr->offCodeSum = MaxOff+1;
 
         }
 
@@ -145,52 +183,37 @@ static void ZSTD_rescaleFreqs(optState_t* const optPtr,
         }
     }
 
-    ZSTD_setLog2Prices(optPtr);
-}
-
-#if 1   /* approximation at bit level */
-#  define BITCOST_ACCURACY 0
-#  define BITCOST_MULTIPLIER (1 << BITCOST_ACCURACY)
-#  define BITCOST_SYMBOL(t,l,s)  ((void)(l), FSE_getMaxNbBits(t,s) * BITCOST_MULTIPLIER)
-#else   /* fractional bit accuracy */
-#  define BITCOST_ACCURACY 8
-#  define BITCOST_MULTIPLIER (1 << BITCOST_ACCURACY)
-#  define BITCOST_SYMBOL(t,l,s)  FSE_bitCost(t,l,s,BITCOST_ACCURACY)
-#endif
-
-MEM_STATIC double
-ZSTD_fCost(U32 price)
-{
-    return (double)price / (BITCOST_MULTIPLIER*8);
+    ZSTD_setBasePrices(optPtr, optLevel);
 }
 
 /* ZSTD_rawLiteralsCost() :
- * cost of literals (only) in specified segment (which length can be 0).
- * does not include cost of literalLength symbol */
+ * price of literals (only) in specified segment (which length can be 0).
+ * does not include price of literalLength symbol */
 static U32 ZSTD_rawLiteralsCost(const BYTE* const literals, U32 const litLength,
-                                const optState_t* const optPtr)
+                                const optState_t* const optPtr,
+                                int optLevel)
 {
     if (litLength == 0) return 0;
-    if (optPtr->priceType == zop_predef) return (litLength*6);  /* 6 bit per literal - no statistic used */
+    if (optPtr->priceType == zop_predef) return (litLength*6) * BITCOST_MULTIPLIER;  /* 6 bit per literal - no statistic used */
 
     /* dynamic statistics */
-    {   U32 u;
-        U32 cost = litLength * optPtr->log2litSum;
+    {   U32 price = litLength * optPtr->litSumBasePrice;
+        U32 u;
         for (u=0; u < litLength; u++)
-            cost -= ZSTD_highbit32(optPtr->litFreq[literals[u]]+1);
-        return cost * BITCOST_MULTIPLIER;
+            price -= WEIGHT(optPtr->litFreq[literals[u]], optLevel);
+        return price;
     }
 }
 
 /* ZSTD_litLengthPrice() :
  * cost of literalLength symbol */
-static U32 ZSTD_litLengthPrice(U32 const litLength, const optState_t* const optPtr)
+static U32 ZSTD_litLengthPrice(U32 const litLength, const optState_t* const optPtr, int optLevel)
 {
-    if (optPtr->priceType == zop_predef) return ZSTD_highbit32((U32)litLength+1);
+    if (optPtr->priceType == zop_predef) return WEIGHT(litLength, optLevel);
 
     /* dynamic statistics */
     {   U32 const llCode = ZSTD_LLcode(litLength);
-        return (LL_bits[llCode] + optPtr->log2litLengthSum - ZSTD_highbit32(optPtr->litLengthFreq[llCode]+1)) * BITCOST_MULTIPLIER;
+        return (LL_bits[llCode] * BITCOST_MULTIPLIER) + (optPtr->litLengthSumBasePrice - WEIGHT(optPtr->litLengthFreq[llCode], optLevel));
     }
 }
 
@@ -198,26 +221,26 @@ static U32 ZSTD_litLengthPrice(U32 const litLength, const optState_t* const optP
  * cost of the literal part of a sequence,
  * including literals themselves, and literalLength symbol */
 static U32 ZSTD_fullLiteralsCost(const BYTE* const literals, U32 const litLength,
-                                 const optState_t* const optPtr)
+                                 const optState_t* const optPtr,
+                                 int optLevel)
 {
-    return ZSTD_rawLiteralsCost(literals, litLength, optPtr)
-         + ZSTD_litLengthPrice(litLength, optPtr);
+    return ZSTD_rawLiteralsCost(literals, litLength, optPtr, optLevel)
+         + ZSTD_litLengthPrice(litLength, optPtr, optLevel);
 }
 
 /* ZSTD_litLengthContribution() :
  * @return ( cost(litlength) - cost(0) )
  * this value can then be added to rawLiteralsCost()
  * to provide a cost which is directly comparable to a match ending at same position */
-static int ZSTD_litLengthContribution(U32 const litLength, const optState_t* const optPtr)
+static int ZSTD_litLengthContribution(U32 const litLength, const optState_t* const optPtr, int optLevel)
 {
-    if (optPtr->priceType >= zop_predef) return ZSTD_highbit32(litLength+1);
+    if (optPtr->priceType >= zop_predef) return WEIGHT(litLength, optLevel);
 
     /* dynamic statistics */
     {   U32 const llCode = ZSTD_LLcode(litLength);
-        int const contribution = (LL_bits[llCode]
-                        + ZSTD_highbit32(optPtr->litLengthFreq[0]+1) /* note: log2litLengthSum cancels out with following one */
-                        - ZSTD_highbit32(optPtr->litLengthFreq[llCode]+1))
-                        * BITCOST_MULTIPLIER;
+        int const contribution = (LL_bits[llCode] * BITCOST_MULTIPLIER)
+                               + WEIGHT(optPtr->litLengthFreq[0], optLevel)   /* note: log2litLengthSum cancel out */
+                               - WEIGHT(optPtr->litLengthFreq[llCode], optLevel);
 #if 1
         return contribution;
 #else
@@ -231,10 +254,11 @@ static int ZSTD_litLengthContribution(U32 const litLength, const optState_t* con
  * which can be compared to the ending cost of a match
  * should a new match start at this position */
 static int ZSTD_literalsContribution(const BYTE* const literals, U32 const litLength,
-                                     const optState_t* const optPtr)
+                                     const optState_t* const optPtr,
+                                     int optLevel)
 {
-    int const contribution = ZSTD_rawLiteralsCost(literals, litLength, optPtr)
-                           + ZSTD_litLengthContribution(litLength, optPtr);
+    int const contribution = ZSTD_rawLiteralsCost(literals, litLength, optPtr, optLevel)
+                           + ZSTD_litLengthContribution(litLength, optPtr, optLevel);
     return contribution;
 }
 
@@ -243,7 +267,8 @@ static int ZSTD_literalsContribution(const BYTE* const literals, U32 const litLe
  * Must be combined with ZSTD_fullLiteralsCost() to get the full cost of a sequence.
  * optLevel: when <2, favors small offset for decompression speed (improved cache efficiency) */
 FORCE_INLINE_TEMPLATE U32
-ZSTD_getMatchPrice(U32 const offset, U32 const matchLength,
+ZSTD_getMatchPrice(U32 const offset,
+                   U32 const matchLength,
                    const optState_t* const optPtr,
                    int const optLevel)
 {
@@ -253,19 +278,20 @@ ZSTD_getMatchPrice(U32 const offset, U32 const matchLength,
     assert(matchLength >= MINMATCH);
 
     if (optPtr->priceType == zop_predef)  /* fixed scheme, do not use statistics */
-        return ZSTD_highbit32(mlBase+1) + 16 + offCode;
+        return WEIGHT(mlBase, optLevel) + ((16 + offCode) * BITCOST_MULTIPLIER);
 
     /* dynamic statistics */
-    price = offCode + optPtr->log2offCodeSum - ZSTD_highbit32(optPtr->offCodeFreq[offCode]+1);
-    if ((optLevel<2) /*static*/ && offCode >= 20) price += (offCode-19)*2; /* handicap for long distance offsets, favor decompression speed */
+    price = (offCode * BITCOST_MULTIPLIER) + (optPtr->offCodeSumBasePrice - WEIGHT(optPtr->offCodeFreq[offCode], optLevel));
+    if ((optLevel<2) /*static*/ && offCode >= 20)
+        price += (offCode-19)*2 * BITCOST_MULTIPLIER; /* handicap for long distance offsets, favor decompression speed */
 
     /* match Length */
     {   U32 const mlCode = ZSTD_MLcode(mlBase);
-        price += ML_bits[mlCode] + optPtr->log2matchLengthSum - ZSTD_highbit32(optPtr->matchLengthFreq[mlCode]+1);
+        price += (ML_bits[mlCode] * BITCOST_MULTIPLIER) + (optPtr->matchLengthSumBasePrice - WEIGHT(optPtr->matchLengthFreq[mlCode], optLevel));
     }
 
     DEBUGLOG(8, "ZSTD_getMatchPrice(ml:%u) = %u", matchLength, price);
-    return price * BITCOST_MULTIPLIER;
+    return price;
 }
 
 static void ZSTD_updateStats(optState_t* const optPtr,
@@ -695,7 +721,8 @@ typedef struct {
 static U32 ZSTD_rawLiteralsCost_cached(
                             cachedLiteralPrice_t* const cachedLitPrice,
                             const BYTE* const anchor, U32 const litlen,
-                            const optState_t* const optStatePtr)
+                            const optState_t* const optStatePtr,
+                            int optLevel)
 {
     U32 startCost;
     U32 remainingLength;
@@ -712,7 +739,7 @@ static U32 ZSTD_rawLiteralsCost_cached(
         remainingLength = litlen;
     }
 
-    {   U32 const rawLitCost = startCost + ZSTD_rawLiteralsCost(startPosition, remainingLength, optStatePtr);
+    {   U32 const rawLitCost = startCost + ZSTD_rawLiteralsCost(startPosition, remainingLength, optStatePtr, optLevel);
         cachedLitPrice->anchor = anchor;
         cachedLitPrice->litlen = litlen;
         cachedLitPrice->rawLitCost = rawLitCost;
@@ -723,19 +750,21 @@ static U32 ZSTD_rawLiteralsCost_cached(
 static U32 ZSTD_fullLiteralsCost_cached(
                             cachedLiteralPrice_t* const cachedLitPrice,
                             const BYTE* const anchor, U32 const litlen,
-                            const optState_t* const optStatePtr)
+                            const optState_t* const optStatePtr,
+                            int optLevel)
 {
-    return ZSTD_rawLiteralsCost_cached(cachedLitPrice, anchor, litlen, optStatePtr)
-         + ZSTD_litLengthPrice(litlen, optStatePtr);
+    return ZSTD_rawLiteralsCost_cached(cachedLitPrice, anchor, litlen, optStatePtr, optLevel)
+         + ZSTD_litLengthPrice(litlen, optStatePtr, optLevel);
 }
 
 static int ZSTD_literalsContribution_cached(
                             cachedLiteralPrice_t* const cachedLitPrice,
                             const BYTE* const anchor, U32 const litlen,
-                            const optState_t* const optStatePtr)
+                            const optState_t* const optStatePtr,
+                            int optLevel)
 {
-    int const contribution = ZSTD_rawLiteralsCost_cached(cachedLitPrice, anchor, litlen, optStatePtr)
-                           + ZSTD_litLengthContribution(litlen, optStatePtr);
+    int const contribution = ZSTD_rawLiteralsCost_cached(cachedLitPrice, anchor, litlen, optStatePtr, optLevel)
+                           + ZSTD_litLengthContribution(litlen, optStatePtr, optLevel);
     return contribution;
 }
 
@@ -765,8 +794,9 @@ ZSTD_compressBlock_opt_generic(ZSTD_matchState_t* ms,
 
     /* init */
     DEBUGLOG(5, "ZSTD_compressBlock_opt_generic");
+    assert(optLevel <= 2);
     ms->nextToUpdate3 = ms->nextToUpdate;
-    ZSTD_rescaleFreqs(optStatePtr, (const BYTE*)src, srcSize);
+    ZSTD_rescaleFreqs(optStatePtr, (const BYTE*)src, srcSize, optLevel);
     ip += (ip==prefixStart);
     memset(&cachedLitPrice, 0, sizeof(cachedLitPrice));
 
@@ -803,7 +833,7 @@ ZSTD_compressBlock_opt_generic(ZSTD_matchState_t* ms,
             }   }
 
             /* set prices for first matches starting position == 0 */
-            {   U32 const literalsPrice = ZSTD_fullLiteralsCost_cached(&cachedLitPrice, anchor, litlen, optStatePtr);
+            {   U32 const literalsPrice = ZSTD_fullLiteralsCost_cached(&cachedLitPrice, anchor, litlen, optStatePtr, optLevel);
                 U32 pos;
                 U32 matchNb;
                 for (pos = 1; pos < minMatch; pos++) {
@@ -838,9 +868,9 @@ ZSTD_compressBlock_opt_generic(ZSTD_matchState_t* ms,
             {   U32 const litlen = (opt[cur-1].mlen == 1) ? opt[cur-1].litlen + 1 : 1;
                 int price;  /* note : contribution can be negative */
                 if (cur > litlen) {
-                    price = opt[cur - litlen].price + ZSTD_literalsContribution(inr-litlen, litlen, optStatePtr);
+                    price = opt[cur - litlen].price + ZSTD_literalsContribution(inr-litlen, litlen, optStatePtr, optLevel);
                 } else {
-                    price = ZSTD_literalsContribution_cached(&cachedLitPrice, anchor, litlen, optStatePtr);
+                    price = ZSTD_literalsContribution_cached(&cachedLitPrice, anchor, litlen, optStatePtr, optLevel);
                 }
                 assert(price < 1000000000); /* overflow check */
                 if (price <= opt[cur].price) {
@@ -871,7 +901,7 @@ ZSTD_compressBlock_opt_generic(ZSTD_matchState_t* ms,
             {   U32 const ll0 = (opt[cur].mlen != 1);
                 U32 const litlen = (opt[cur].mlen == 1) ? opt[cur].litlen : 0;
                 U32 const previousPrice = (cur > litlen) ? opt[cur-litlen].price : 0;
-                U32 const basePrice = previousPrice + ZSTD_fullLiteralsCost(inr-litlen, litlen, optStatePtr);
+                U32 const basePrice = previousPrice + ZSTD_fullLiteralsCost(inr-litlen, litlen, optStatePtr, optLevel);
                 U32 const nbMatches = ZSTD_BtGetAllMatches(ms, cParams, inr, iend, extDict, opt[cur].rep, ll0, matches, minMatch);
                 U32 matchNb;
                 if (!nbMatches) {
@@ -973,7 +1003,7 @@ _shortestPath:   /* cur, last_pos, best_mlen, best_off have to be set */
                 ZSTD_storeSeq(seqStore, llen, anchor, offset, mlen-MINMATCH);
                 anchor = ip;
         }   }
-        ZSTD_setLog2Prices(optStatePtr);
+        ZSTD_setBasePrices(optStatePtr, optLevel);
     }   /* while (ip < ilimit) */
 
     /* Return the last literals size */