]> git.ipfire.org Git - thirdparty/zstd.git/commitdiff
Approximate FSE encoding costs for selection
authorNick Terrell <terrelln@fb.com>
Mon, 16 Apr 2018 22:37:27 +0000 (15:37 -0700)
committerNick Terrell <terrelln@fb.com>
Tue, 22 May 2018 21:33:22 +0000 (14:33 -0700)
Estimate the cost for using FSE modes `set_basic`, `set_compressed`, and
`set_repeat`, and select the one with the lowest cost.

* The cost of `set_basic` is computed using the cross-entropy cost
  function `ZSTD_crossEntropyCost()`, using the normalized default count
  and the count.
* The cost of `set_repeat` is computed using `FSE_bitCost()`. We check the
  previous table to see if it is able to represent the distribution.
* The cost of `set_compressed` is computed with the entropy cost function
  `ZSTD_entropyCost()`, together with the cost of writing the normalized
  count `ZSTD_NCountCost()`.

lib/common/fse.h
lib/compress/zstd_compress.c
lib/compress/zstd_compress_internal.h

index 5a2344441566896456552b5b1c8190b7d9523696..3d11a75e84e4dad4d7ae38a7e2feb7748275d771 100644 (file)
@@ -402,6 +402,7 @@ typedef struct {
     const void* stateTable;
     const void* symbolTT;
     unsigned    stateLog;
+    unsigned    maxSymbolValue;
 } FSE_CState_t;
 
 static void FSE_initCState(FSE_CState_t* CStatePtr, const FSE_CTable* ct);
@@ -538,11 +539,13 @@ MEM_STATIC void FSE_initCState(FSE_CState_t* statePtr, const FSE_CTable* ct)
 {
     const void* ptr = ct;
     const U16* u16ptr = (const U16*) ptr;
-    const U32 tableLog = MEM_read16(ptr);
+    const U32 tableLog = MEM_read16(u16ptr);
+    const U32 maxSymbolValue = MEM_read16(u16ptr + 1);
     statePtr->value = (ptrdiff_t)1<<tableLog;
     statePtr->stateTable = u16ptr+2;
     statePtr->symbolTT = ((const U32*)ct + 1 + (tableLog ? (1<<(tableLog-1)) : 1));
     statePtr->stateLog = tableLog;
+    statePtr->maxSymbolValue = maxSymbolValue;
 }
 
 
@@ -581,12 +584,13 @@ MEM_STATIC U32 FSE_getMaxNbBits(const void* symbolTTPtr, U32 symbolValue)
     return (symbolTT[symbolValue].deltaNbBits + ((1<<16)-1)) >> 16;
 }
 
-/* FSE_bitCost_b256() :
+/* FSE_bitCost() :
  * Approximate symbol cost,
  * provide fractional value, using fixed-point format (accuracyLog fractional bits)
  * note: assume symbolValue is valid */
-MEM_STATIC U32 FSE_bitCost(const FSE_symbolCompressionTransform* symbolTT, U32 tableLog, U32 symbolValue, U32 accuracyLog)
+MEM_STATIC U32 FSE_bitCost(const void* symbolTTPtr, U32 tableLog, U32 symbolValue, U32 accuracyLog)
 {
+    const FSE_symbolCompressionTransform* symbolTT = (const FSE_symbolCompressionTransform*) symbolTTPtr;
     U32 const minNbBits = symbolTT[symbolValue].deltaNbBits >> 16;
     U32 const threshold = (minNbBits+1) << 16;
     assert(tableLog < 16);
index d8420a8a65a1de731767bfc119cf3934ec3f2f8f..114845b150ca7832c656b9593cc66afd96c7bc58 100644 (file)
@@ -1561,6 +1561,129 @@ void ZSTD_seqToCodes(const seqStore_t* seqStorePtr)
         mlCodeTable[seqStorePtr->longLengthPos] = MaxML;
 }
 
+
+/**
+ * -log2(x / 256) lookup table for x in [0, 256).
+ * If x == 0: Return 0
+ * Else: Return floor(-log2(x / 256) * 256)
+ */
+static unsigned const kInverseProbabiltyLog256[256] = {
+    0,    2048, 1792, 1642, 1536, 1453, 1386, 1329, 1280, 1236, 1197, 1162,
+    1130, 1100, 1073, 1047, 1024, 1001, 980,  960,  941,  923,  906,  889,
+    874,  859,  844,  830,  817,  804,  791,  779,  768,  756,  745,  734,
+    724,  714,  704,  694,  685,  676,  667,  658,  650,  642,  633,  626,
+    618,  610,  603,  595,  588,  581,  574,  567,  561,  554,  548,  542,
+    535,  529,  523,  517,  512,  506,  500,  495,  489,  484,  478,  473,
+    468,  463,  458,  453,  448,  443,  438,  434,  429,  424,  420,  415,
+    411,  407,  402,  398,  394,  390,  386,  382,  377,  373,  370,  366,
+    362,  358,  354,  350,  347,  343,  339,  336,  332,  329,  325,  322,
+    318,  315,  311,  308,  305,  302,  298,  295,  292,  289,  286,  282,
+    279,  276,  273,  270,  267,  264,  261,  258,  256,  253,  250,  247,
+    244,  241,  239,  236,  233,  230,  228,  225,  222,  220,  217,  215,
+    212,  209,  207,  204,  202,  199,  197,  194,  192,  190,  187,  185,
+    182,  180,  178,  175,  173,  171,  168,  166,  164,  162,  159,  157,
+    155,  153,  151,  149,  146,  144,  142,  140,  138,  136,  134,  132,
+    130,  128,  126,  123,  121,  119,  117,  115,  114,  112,  110,  108,
+    106,  104,  102,  100,  98,   96,   94,   93,   91,   89,   87,   85,
+    83,   82,   80,   78,   76,   74,   73,   71,   69,   67,   66,   64,
+    62,   61,   59,   57,   55,   54,   52,   50,   49,   47,   46,   44,
+    42,   41,   39,   37,   36,   34,   33,   31,   30,   28,   26,   25,
+    23,   22,   20,   19,   17,   16,   14,   13,   11,   10,   8,    7,
+    5,    4,    2,    1,
+};
+
+
+/**
+ * Returns the cost in bits of encoding the distribution described by count
+ * using the entropy bound.
+ */
+static size_t ZSTD_entropyCost(unsigned const* count, unsigned const max, size_t const total)
+{
+    unsigned cost = 0;
+    unsigned s;
+    for (s = 0; s <= max; ++s) {
+        unsigned norm = (unsigned)((256 * count[s]) / total);
+        if (count[s] != 0 && norm == 0)
+            norm = 1;
+        assert(count[s] < total);
+        cost += count[s] * kInverseProbabiltyLog256[norm];
+    }
+    return cost >> 8;
+}
+
+
+/**
+ * Returns the cost in bits of encoding the distribution in count using the
+ * table described by norm. The max symbol support by norm is assumed >= max.
+ * norm must be valid for every symbol with non-zero probability in count.
+ */
+static size_t ZSTD_crossEntropyCost(short const* norm, unsigned accuracyLog,
+                                    unsigned const* count, unsigned const max)
+{
+    unsigned const shift = 8 - accuracyLog;
+    size_t cost = 0;
+    unsigned s;
+    assert(accuracyLog <= 8);
+    for (s = 0; s <= max; ++s) {
+        unsigned const normAcc = norm[s] != -1 ? norm[s] : 1;
+        unsigned const norm256 = normAcc << shift;
+        assert(norm256 > 0);
+        assert(norm256 < 256);
+        cost += count[s] * kInverseProbabiltyLog256[norm256];
+    }
+    return cost >> 8;
+}
+
+
+/**
+ * Returns the cost in bits of encoding the distribution in count using ctable.
+ * Returns an error if ctable cannot represent all the symbols in count.
+ */
+static size_t ZSTD_fseBitCost(
+    FSE_CTable const* ctable,
+    unsigned const* count,
+    unsigned const max)
+{
+    unsigned const kAccuracyLog = 8;
+    size_t cost = 0;
+    unsigned s;
+    FSE_CState_t cstate;
+    FSE_initCState(&cstate, ctable);
+    if (cstate.maxSymbolValue < max) {
+        DEBUGLOG(5, "Repeat FSE_CTable has maxSymbolValue %u < %u",
+                    cstate.maxSymbolValue, max);
+        return ERROR(GENERIC);
+    }
+    for (s = 0; s <= max; ++s) {
+        unsigned const tableLog = cstate.stateLog;
+        unsigned const badCost = (tableLog + 1) << kAccuracyLog;
+        unsigned const bitCost = FSE_bitCost(cstate.symbolTT, tableLog, s, kAccuracyLog);
+        if (count[s] == 0)
+            continue;
+        if (bitCost >= badCost) {
+            DEBUGLOG(5, "Repeat FSE_CTable has Prob[%u] == 0", s);
+            return ERROR(GENERIC);
+        }
+        cost += count[s] * bitCost;
+    }
+    return cost >> kAccuracyLog;
+}
+
+/**
+ * Returns the cost in bytes of encoding the normalized count header.
+ * Returns an error if any of the helper functions return an error.
+ */
+static size_t ZSTD_NCountCost(unsigned const* count, unsigned const max,
+                              size_t const nbSeq, unsigned const FSELog)
+{
+    BYTE wksp[FSE_NCOUNTBOUND];
+    S16 norm[MaxSeq + 1];
+    const U32 tableLog = FSE_optimalTableLog(FSELog, nbSeq, max);
+    CHECK_F(FSE_normalizeCount(norm, tableLog, count, nbSeq, max));
+    return FSE_writeNCount(wksp, sizeof(wksp), norm, max, tableLog);
+}
+
+
 typedef enum {
     ZSTD_defaultDisallowed = 0,
     ZSTD_defaultAllowed = 1
@@ -1568,37 +1691,73 @@ typedef enum {
 
 MEM_STATIC
 symbolEncodingType_e ZSTD_selectEncodingType(
-        FSE_repeat* repeatMode, size_t const mostFrequent, size_t nbSeq,
-        U32 defaultNormLog, ZSTD_defaultPolicy_e const isDefaultAllowed)
+        FSE_repeat* repeatMode, unsigned const* count, unsigned const max,
+        size_t const mostFrequent, size_t nbSeq, unsigned const FSELog,
+        FSE_CTable const* prevCTable,
+        short const* defaultNorm, U32 defaultNormLog,
+        ZSTD_defaultPolicy_e const isDefaultAllowed,
+        ZSTD_strategy const strategy)
 {
 #define MIN_SEQ_FOR_DYNAMIC_FSE   64
 #define MAX_SEQ_FOR_STATIC_FSE  1000
     ZSTD_STATIC_ASSERT(ZSTD_defaultDisallowed == 0 && ZSTD_defaultAllowed != 0);
-    if ((mostFrequent == nbSeq) && (!isDefaultAllowed || nbSeq > 2)) {
+    if (mostFrequent == nbSeq) {
+        *repeatMode = FSE_repeat_none;
+        if (isDefaultAllowed && nbSeq <= 2) {
+            /* Prefer set_basic over set_rle when there are 2 or less symbols,
+             * since RLE uses 1 byte, but set_basic uses 5-6 bits per symbol.
+             * If basic encoding isn't possible, always choose RLE.
+             */
+            DEBUGLOG(5, "Selected set_basic");
+            return set_basic;
+        }
         DEBUGLOG(5, "Selected set_rle");
-        /* Prefer set_basic over set_rle when there are 2 or less symbols,
-         * since RLE uses 1 byte, but set_basic uses 5-6 bits per symbol.
-         * If basic encoding isn't possible, always choose RLE.
-         */
-        *repeatMode = FSE_repeat_check;
         return set_rle;
     }
-    if ( isDefaultAllowed
-      && (*repeatMode == FSE_repeat_valid) && (nbSeq < MAX_SEQ_FOR_STATIC_FSE)) {
-        DEBUGLOG(5, "Selected set_repeat");
-        return set_repeat;
-    }
-    if ( isDefaultAllowed
-      && ((nbSeq < MIN_SEQ_FOR_DYNAMIC_FSE) || (mostFrequent < (nbSeq >> (defaultNormLog-1)))) ) {
-        DEBUGLOG(5, "Selected set_basic");
-        /* The format allows default tables to be repeated, but it isn't useful.
-         * When using simple heuristics to select encoding type, we don't want
-         * to confuse these tables with dictionaries. When running more careful
-         * analysis, we don't need to waste time checking both repeating tables
-         * and default tables.
-         */
-        *repeatMode = FSE_repeat_none;
-        return set_basic;
+    if (strategy < ZSTD_lazy) {
+        if (isDefaultAllowed) {
+            if ((*repeatMode == FSE_repeat_valid) && (nbSeq < MAX_SEQ_FOR_STATIC_FSE)) {
+                DEBUGLOG(5, "Selected set_repeat");
+                return set_repeat;
+            }
+            if ((nbSeq < MIN_SEQ_FOR_DYNAMIC_FSE) || (mostFrequent < (nbSeq >> (defaultNormLog-1)))) {
+                DEBUGLOG(5, "Selected set_basic");
+                /* The format allows default tables to be repeated, but it isn't useful.
+                 * When using simple heuristics to select encoding type, we don't want
+                 * to confuse these tables with dictionaries. When running more careful
+                 * analysis, we don't need to waste time checking both repeating tables
+                 * and default tables.
+                 */
+                *repeatMode = FSE_repeat_none;
+                return set_basic;
+            }
+        }
+    } else {
+        size_t const basicCost = isDefaultAllowed ? ZSTD_crossEntropyCost(defaultNorm, defaultNormLog, count, max) : ERROR(GENERIC);
+        size_t const repeatCost = *repeatMode != FSE_repeat_none ? ZSTD_fseBitCost(prevCTable, count, max) : ERROR(GENERIC);
+        size_t const NCountCost = ZSTD_NCountCost(count, max, nbSeq, FSELog);
+        size_t const compressedCost = (NCountCost << 3) + ZSTD_entropyCost(count, max, nbSeq);
+
+        if (isDefaultAllowed) {
+            assert(!ZSTD_isError(basicCost));
+            assert(!(*repeatMode == FSE_repeat_valid && ZSTD_isError(repeatCost)));
+        }
+        assert(!ZSTD_isError(NCountCost));
+        assert(compressedCost < ERROR(maxCode));
+        DEBUGLOG(5, "Estimated bit costs: basic=%u\trepeat=%u\tcompressed=%u",
+                    (U32)basicCost, (U32)repeatCost, (U32)compressedCost);
+        if (basicCost <= repeatCost && basicCost <= compressedCost) {
+            DEBUGLOG(5, "Selected set_basic");
+            assert(isDefaultAllowed);
+            *repeatMode = FSE_repeat_none;
+            return set_basic;
+        }
+        if (repeatCost <= compressedCost) {
+            DEBUGLOG(5, "Selected set_repeat");
+            assert(!ZSTD_isError(repeatCost));
+            return set_repeat;
+        }
+        assert(compressedCost < basicCost && compressedCost < repeatCost);
     }
     DEBUGLOG(5, "Selected set_compressed");
     *repeatMode = FSE_repeat_check;
@@ -1803,6 +1962,7 @@ MEM_STATIC size_t ZSTD_compressSequences_internal(seqStore_t* seqStorePtr,
                               const int bmi2)
 {
     const int longOffsets = cctxParams->cParams.windowLog > STREAM_ACCUMULATOR_MIN;
+    ZSTD_strategy const strategy = cctxParams->cParams.strategy;
     U32 count[MaxSeq+1];
     FSE_CTable* CTable_LitLength = nextEntropy->litlengthCTable;
     FSE_CTable* CTable_OffsetBits = nextEntropy->offcodeCTable;
@@ -1844,13 +2004,20 @@ MEM_STATIC size_t ZSTD_compressSequences_internal(seqStore_t* seqStorePtr,
     else
         op[0]=0xFF, MEM_writeLE16(op+1, (U16)(nbSeq - LONGNBSEQ)), op+=3;
     if (nbSeq==0) {
-      memcpy(nextEntropy->litlengthCTable, prevEntropy->litlengthCTable, sizeof(prevEntropy->litlengthCTable));
-      nextEntropy->litlength_repeatMode = prevEntropy->litlength_repeatMode;
-      memcpy(nextEntropy->offcodeCTable, prevEntropy->offcodeCTable, sizeof(prevEntropy->offcodeCTable));
-      nextEntropy->offcode_repeatMode = prevEntropy->offcode_repeatMode;
-      memcpy(nextEntropy->matchlengthCTable, prevEntropy->matchlengthCTable, sizeof(prevEntropy->matchlengthCTable));
-      nextEntropy->matchlength_repeatMode = prevEntropy->matchlength_repeatMode;
-      return op - ostart;
+        /* Check that all the Huffman data is first */
+        ZSTD_STATIC_ASSERT(offsetof(ZSTD_entropyCTables_t, hufCTable) == 0);
+        ZSTD_STATIC_ASSERT(
+            offsetof(ZSTD_entropyCTables_t, hufCTable_repeatMode) ==
+            sizeof(prevEntropy->hufCTable));
+        ZSTD_STATIC_ASSERT(
+            offsetof(ZSTD_entropyCTables_t, offcodeCTable) ==
+            sizeof(prevEntropy->hufCTable) + sizeof(prevEntropy->hufCTable_repeatMode));
+        /* Copy starting at the first FSE element */
+        memcpy(
+            nextEntropy->offcodeCTable,
+            prevEntropy->offcodeCTable,
+            sizeof(*prevEntropy) - offsetof(ZSTD_entropyCTables_t, offcodeCTable));
+        return op - ostart;
     }
 
     /* seqHead : flags for FSE encoding type */
@@ -1863,7 +2030,9 @@ MEM_STATIC size_t ZSTD_compressSequences_internal(seqStore_t* seqStorePtr,
         size_t const mostFrequent = FSE_countFast_wksp(count, &max, llCodeTable, nbSeq, workspace);
         DEBUGLOG(5, "Building LL table");
         nextEntropy->litlength_repeatMode = prevEntropy->litlength_repeatMode;
-        LLtype = ZSTD_selectEncodingType(&nextEntropy->litlength_repeatMode, mostFrequent, nbSeq, LL_defaultNormLog, ZSTD_defaultAllowed);
+        LLtype = ZSTD_selectEncodingType(&nextEntropy->litlength_repeatMode, count, max, mostFrequent, nbSeq, LLFSELog, prevEntropy->litlengthCTable, LL_defaultNorm, LL_defaultNormLog, ZSTD_defaultAllowed, strategy);
+        assert(set_basic < set_compressed && set_rle < set_compressed);
+        assert(!(LLtype < set_compressed && nextEntropy->litlength_repeatMode != FSE_repeat_none)); /* We don't copy tables */
         {   size_t const countSize = ZSTD_buildCTable(op, oend - op, CTable_LitLength, LLFSELog, (symbolEncodingType_e)LLtype,
                     count, max, llCodeTable, nbSeq, LL_defaultNorm, LL_defaultNormLog, MaxLL,
                     prevEntropy->litlengthCTable, sizeof(prevEntropy->litlengthCTable),
@@ -1878,7 +2047,8 @@ MEM_STATIC size_t ZSTD_compressSequences_internal(seqStore_t* seqStorePtr,
         ZSTD_defaultPolicy_e const defaultPolicy = (max <= DefaultMaxOff) ? ZSTD_defaultAllowed : ZSTD_defaultDisallowed;
         DEBUGLOG(5, "Building OF table");
         nextEntropy->offcode_repeatMode = prevEntropy->offcode_repeatMode;
-        Offtype = ZSTD_selectEncodingType(&nextEntropy->offcode_repeatMode, mostFrequent, nbSeq, OF_defaultNormLog, defaultPolicy);
+        Offtype = ZSTD_selectEncodingType(&nextEntropy->offcode_repeatMode, count, max, mostFrequent, nbSeq, OffFSELog, prevEntropy->offcodeCTable, OF_defaultNorm, OF_defaultNormLog, defaultPolicy, strategy);
+        assert(!(Offtype < set_compressed && nextEntropy->offcode_repeatMode != FSE_repeat_none)); /* We don't copy tables */
         {   size_t const countSize = ZSTD_buildCTable(op, oend - op, CTable_OffsetBits, OffFSELog, (symbolEncodingType_e)Offtype,
                     count, max, ofCodeTable, nbSeq, OF_defaultNorm, OF_defaultNormLog, DefaultMaxOff,
                     prevEntropy->offcodeCTable, sizeof(prevEntropy->offcodeCTable),
@@ -1891,7 +2061,8 @@ MEM_STATIC size_t ZSTD_compressSequences_internal(seqStore_t* seqStorePtr,
         size_t const mostFrequent = FSE_countFast_wksp(count, &max, mlCodeTable, nbSeq, workspace);
         DEBUGLOG(5, "Building ML table");
         nextEntropy->matchlength_repeatMode = prevEntropy->matchlength_repeatMode;
-        MLtype = ZSTD_selectEncodingType(&nextEntropy->matchlength_repeatMode, mostFrequent, nbSeq, ML_defaultNormLog, ZSTD_defaultAllowed);
+        MLtype = ZSTD_selectEncodingType(&nextEntropy->matchlength_repeatMode, count, max, mostFrequent, nbSeq, MLFSELog, prevEntropy->matchlengthCTable, ML_defaultNorm, ML_defaultNormLog, ZSTD_defaultAllowed, strategy);
+        assert(!(MLtype < set_compressed && nextEntropy->matchlength_repeatMode != FSE_repeat_none)); /* We don't copy tables */
         {   size_t const countSize = ZSTD_buildCTable(op, oend - op, CTable_MatchLength, MLFSELog, (symbolEncodingType_e)MLtype,
                     count, max, mlCodeTable, nbSeq, ML_defaultNorm, ML_defaultNormLog, MaxML,
                     prevEntropy->matchlengthCTable, sizeof(prevEntropy->matchlengthCTable),
index 0f1830a5e691a28f507a88611a39388e15d7197c..6c4e8bdc38170608f36a5ef3c620b37989035ca5 100644 (file)
@@ -53,11 +53,15 @@ typedef struct ZSTD_prefixDict_s {
 } ZSTD_prefixDict;
 
 typedef struct {
+    /* Huffman data
+     * Must be before the FSE data.
+     */
     U32 hufCTable[HUF_CTABLE_SIZE_U32(255)];
+    HUF_repeat hufCTable_repeatMode;
+    /* FSE data */
     FSE_CTable offcodeCTable[FSE_CTABLE_SIZE_U32(OffFSELog, MaxOff)];
     FSE_CTable matchlengthCTable[FSE_CTABLE_SIZE_U32(MLFSELog, MaxML)];
     FSE_CTable litlengthCTable[FSE_CTABLE_SIZE_U32(LLFSELog, MaxLL)];
-    HUF_repeat hufCTable_repeatMode;
     FSE_repeat offcode_repeatMode;
     FSE_repeat matchlength_repeatMode;
     FSE_repeat litlength_repeatMode;