git.ipfire.org Git - thirdparty/zstd.git/commitdiff
opt: estimate cost of both Huffman and FSE symbols
author Yann Collet <cyan@fb.com>
Tue, 8 May 2018 23:11:21 +0000 (16:11 -0700)
committer Yann Collet <cyan@fb.com>
Tue, 8 May 2018 23:11:21 +0000 (16:11 -0700)
For FSE symbols: provide an upper bound,
in number of bits,
since the cost function is not able to store fractional bit costs.

lib/common/fse.h
lib/compress/zstd_opt.c

index 6a1d272be5cbd25a68bede8e711cf4519708aac9..556e6c523a5005487b9842a10e444a8f78c57cb3 100644 (file)
@@ -575,6 +575,12 @@ MEM_STATIC void FSE_flushCState(BIT_CStream_t* bitC, const FSE_CState_t* statePt
     BIT_flushBits(bitC);
 }
 
+MEM_STATIC U32 FSE_getMaxNbBits(const FSE_symbolCompressionTransform* symbolTT, U32 symbolValue)
+{
+    assert(symbolValue <= FSE_MAX_SYMBOL_VALUE);
+    return (symbolTT[symbolValue].deltaNbBits + ((1<<16)-1)) >> 16;
+}
+
 
 /* ======    Decompression    ====== */
 
index ecd676a68970a0a7bea32b86e54b5aa678228419..a107fccc24efaa979452c8d0b6cd5849aa401cae 100644 (file)
@@ -38,7 +38,7 @@ static void ZSTD_rescaleFreqs(optState_t* const optPtr,
         unsigned u;
         if (srcSize <= 1024) optPtr->priceType = zop_predef;
         assert(optPtr->symbolCosts != NULL);
-        if (0 && optPtr->symbolCosts->hufCTable_repeatMode == HUF_repeat_valid) { /* huffman table presumed generated by dictionary */
+        if (optPtr->symbolCosts->hufCTable_repeatMode == HUF_repeat_valid) { /* huffman table presumed generated by dictionary */
             optPtr->priceType = zop_static;
         }
 
@@ -122,12 +122,17 @@ static U32 ZSTD_rawLiteralsCost(const BYTE* const literals, U32 const litLength,
  * cost of literalLength symbol */
 static U32 ZSTD_litLengthPrice(U32 const litLength, const optState_t* const optPtr)
 {
+    if (optPtr->priceType == zop_static) {
+        U32 const llCode = ZSTD_LLcode(litLength);
+        FSE_CState_t cstate;
+        FSE_initCState(&cstate, optPtr->symbolCosts->litlengthCTable);
+        return LL_bits[llCode] + FSE_getMaxNbBits(cstate.symbolTT, llCode);
+    }
     if (optPtr->priceType == zop_predef) return ZSTD_highbit32((U32)litLength+1);
 
     /* literal Length */
     {   U32 const llCode = ZSTD_LLcode(litLength);
-        U32 const price = LL_bits[llCode] + optPtr->log2litLengthSum - ZSTD_highbit32(optPtr->litLengthFreq[llCode]+1);
-        return price;
+        return LL_bits[llCode] + optPtr->log2litLengthSum - ZSTD_highbit32(optPtr->litLengthFreq[llCode]+1);
     }
 }
 
@@ -147,7 +152,13 @@ static U32 ZSTD_fullLiteralsCost(const BYTE* const literals, U32 const litLength
  * to provide a cost which is directly comparable to a match ending at same position */
 static int ZSTD_litLengthContribution(U32 const litLength, const optState_t* const optPtr)
 {
-    if (optPtr->priceType == zop_predef) return ZSTD_highbit32(litLength+1);
+    if (optPtr->priceType == zop_static) {
+        U32 const llCode = ZSTD_LLcode(litLength);
+        FSE_CState_t cstate;
+        FSE_initCState(&cstate, optPtr->symbolCosts->litlengthCTable);
+        return (int)(LL_bits[llCode] + FSE_getMaxNbBits(cstate.symbolTT, llCode)) - FSE_getMaxNbBits(cstate.symbolTT, 0);
+    }
+    if (optPtr->priceType >= zop_predef) return ZSTD_highbit32(litLength+1);
 
     /* literal Length */
     {   U32 const llCode = ZSTD_LLcode(litLength);
@@ -188,6 +199,15 @@ ZSTD_getMatchPrice(U32 const offset, U32 const matchLength,
     U32 const mlBase = matchLength - MINMATCH;
     assert(matchLength >= MINMATCH);
 
+    if (optPtr->priceType == zop_static) {
+        U32 const mlCode = ZSTD_MLcode(mlBase);
+        FSE_CState_t mlstate, offstate;
+        FSE_initCState(&mlstate, optPtr->symbolCosts->matchlengthCTable);
+        FSE_initCState(&offstate, optPtr->symbolCosts->offcodeCTable);
+        return FSE_getMaxNbBits(offstate.symbolTT, offCode) + offCode
+             + FSE_getMaxNbBits(mlstate.symbolTT, mlCode) + ML_bits[mlCode];
+    }
+
     if (optPtr->priceType == zop_predef)  /* fixed scheme, do not use statistics */
         return ZSTD_highbit32(mlBase+1) + 16 + offCode;