]> git.ipfire.org Git - thirdparty/zstd.git/commitdiff
fix corner case when requiring cost of an FSE symbol
authorYann Collet <cyan@fb.com>
Thu, 24 May 2018 20:59:11 +0000 (13:59 -0700)
committerYann Collet <cyan@fb.com>
Thu, 24 May 2018 20:59:11 +0000 (13:59 -0700)
ensure that, when frequency[symbol]==0,
result is (tableLog + 1) bits
with both upper-bit and fractional-bit estimates.

Also : enable BIT_DEBUG in /tests

lib/common/fse.h
lib/compress/fse_compress.c
tests/Makefile

index 5a2344441566896456552b5b1c8190b7d9523696..cd810c7d5f2dd7927ced081d8fa61a1e71ee44b9 100644 (file)
@@ -575,16 +575,22 @@ MEM_STATIC void FSE_flushCState(BIT_CStream_t* bitC, const FSE_CState_t* statePt
     BIT_flushBits(bitC);
 }
 
+
+/* FSE_getMaxNbBits() :
+ * Approximate maximum cost of a symbol, in bits.
+ * Fractional get rounded up (i.e : a symbol with a normalized frequency of 3 gives the same result as a frequency of 2)
+ * note 1 : assume symbolValue is valid (<= maxSymbolValue)
+ * note 2 : if freq[symbolValue]==0, @return a fake cost of tableLog+1 bits */
 MEM_STATIC U32 FSE_getMaxNbBits(const void* symbolTTPtr, U32 symbolValue)
 {
     const FSE_symbolCompressionTransform* symbolTT = (const FSE_symbolCompressionTransform*) symbolTTPtr;
     return (symbolTT[symbolValue].deltaNbBits + ((1<<16)-1)) >> 16;
 }
 
-/* FSE_bitCost_b256() :
- * Approximate symbol cost,
- * provide fractional value, using fixed-point format (accuracyLog fractional bits)
- * note: assume symbolValue is valid */
+/* FSE_bitCost() :
+ * Approximate symbol cost, as fractional value, using fixed-point format (accuracyLog fractional bits)
+ * note 1 : assume symbolValue is valid (<= maxSymbolValue)
+ * note 2 : if freq[symbolValue]==0, @return a fake cost of tableLog+1 bits */
 MEM_STATIC U32 FSE_bitCost(const FSE_symbolCompressionTransform* symbolTT, U32 tableLog, U32 symbolValue, U32 accuracyLog)
 {
     U32 const minNbBits = symbolTT[symbolValue].deltaNbBits >> 16;
@@ -592,13 +598,13 @@ MEM_STATIC U32 FSE_bitCost(const FSE_symbolCompressionTransform* symbolTT, U32 t
     assert(tableLog < 16);
     assert(accuracyLog < 31-tableLog);  /* ensure enough room for renormalization double shift */
     {   U32 const tableSize = 1 << tableLog;
+        U32 const deltaFromThreshold = threshold - (symbolTT[symbolValue].deltaNbBits + tableSize);
+        U32 const normalizedDeltaFromThreshold = (deltaFromThreshold << accuracyLog) >> tableLog;   /* linear interpolation (very approximate) */
+        U32 const bitMultiplier = 1 << accuracyLog;
         assert(symbolTT[symbolValue].deltaNbBits + tableSize <= threshold);
-        {   U32 const deltaFromThreshold = threshold - (symbolTT[symbolValue].deltaNbBits + tableSize);
-            U32 const normalizedDeltaFromThreshold = (deltaFromThreshold << accuracyLog) >> tableLog;   /* linear interpolation (very approximate) */
-            U32 const bitMultiplier = 1 << accuracyLog;
-            assert(normalizedDeltaFromThreshold <= bitMultiplier);
-            return (minNbBits+1)*bitMultiplier - normalizedDeltaFromThreshold;
-    }   }
+        assert(normalizedDeltaFromThreshold <= bitMultiplier);
+        return (minNbBits+1)*bitMultiplier - normalizedDeltaFromThreshold;
+    }
 }
 
 
index 5df92db454292382da68c8db23e4d60432a632fd..ab2eb335fa1fdbe53136473416ebd29db32a4247 100644 (file)
@@ -100,6 +100,7 @@ size_t FSE_buildCTable_wksp(FSE_CTable* ct, const short* normalizedCounter, unsi
     if (((size_t)1 << tableLog) * sizeof(FSE_FUNCTION_TYPE) > wkspSize) return ERROR(tableLog_tooLarge);
     tableU16[-2] = (U16) tableLog;
     tableU16[-1] = (U16) maxSymbolValue;
+    assert(tableLog < 16);   /* required for the threshold strategy to work */
 
     /* For explanations on how to distribute symbol values over the table :
     *  http://fastcompression.blogspot.fr/2014/02/fse-distributing-symbol-values.html */
@@ -145,7 +146,7 @@ size_t FSE_buildCTable_wksp(FSE_CTable* ct, const short* normalizedCounter, unsi
             {
             case  0:
                 /* filling nonetheless, for compatibility with FSE_getMaxNbBits() */
-                symbolTT[s].deltaNbBits = (tableLog+1) << 16;
+                symbolTT[s].deltaNbBits = ((tableLog+1) << 16) - (1<<tableLog);
                 break;
 
             case -1:
index 5b35ad406318d1c43d85a8d4155fbd11190229c4..c4cbe1bdfcd8f555d29867a41198a444d93ffe10 100644 (file)
@@ -24,7 +24,7 @@ PYTHON ?= python3
 TESTARTEFACT := versionsTest
 
 DEBUGLEVEL ?= 1
-DEBUGFLAGS  = -g -DZSTD_DEBUG=$(DEBUGLEVEL)
+DEBUGFLAGS  = -g -DZSTD_DEBUG=$(DEBUGLEVEL) -DBIT_DEBUG=$(DEBUGLEVEL)
 CPPFLAGS   += -I$(ZSTDDIR) -I$(ZSTDDIR)/common -I$(ZSTDDIR)/compress \
               -I$(ZSTDDIR)/dictBuilder -I$(ZSTDDIR)/deprecated -I$(PRGDIR)
 CFLAGS     ?= -O3