]> git.ipfire.org Git - thirdparty/zstd.git/commitdiff
[fuzz] Only set HUF_repeat_valid if loaded table has all non-zero weights (#1898)
author: Bimba Shrestha <bimbashrestha@fb.com>
Tue, 26 Nov 2019 20:24:19 +0000 (12:24 -0800)
committer: Nick Terrell <terrelln@fb.com>
Tue, 26 Nov 2019 20:24:19 +0000 (12:24 -0800)
Fixes a fuzz issue where dictionary_round_trip failed because the compressor was generating corrupt files thanks to zero weights in the table.

* Only setting loaded dict huf table to valid when all weights are non-zero

* Adding hasNoZeroWeights test to fse tables

* Forbidding nbBits != 0 when weight == 0

* Reverting the last commit

* Setting table log to 0 when weight == 0

* Small (invalid) zero weight dict test

* Small (valid) zero weight dict test

* Initializing repeatMode vars to check before zero check

* Removing FSE changes to separate PR

* Reverting accidentally changed file

* Negating bool, using unsigned, optimization nit

lib/common/huf.h
lib/compress/huf_compress.c
lib/compress/zstd_compress.c
tests/dict-files/zero-weight-dict [new file with mode: 0644]
tests/playTests.sh

index 3026c43ea03c8446271e54eac3e0118665b56035..4a87db5c1625d2767b6796767d89a23b948af12c 100644 (file)
@@ -247,7 +247,7 @@ size_t HUF_readStats(BYTE* huffWeight, size_t hwSize,
 
 /** HUF_readCTable() :
  *  Loading a CTable saved with HUF_writeCTable() */
-size_t HUF_readCTable (HUF_CElt* CTable, unsigned* maxSymbolValuePtr, const void* src, size_t srcSize);
+size_t HUF_readCTable (HUF_CElt* CTable, unsigned* maxSymbolValuePtr, const void* src, size_t srcSize, unsigned *hasZeroWeights);
 
 /** HUF_getNbBits() :
  *  Read nbBits from CTable symbolTable, for symbol `symbolValue` presumed <= HUF_SYMBOLVALUE_MAX
index 0cbba2c994e37934edcf0663eab210a29858bc2a..b8e6fb386767b0562b8728878c85f6bcc47a94ef 100644 (file)
@@ -169,7 +169,7 @@ size_t HUF_writeCTable (void* dst, size_t maxDstSize,
 }
 
 
-size_t HUF_readCTable (HUF_CElt* CTable, unsigned* maxSymbolValuePtr, const void* src, size_t srcSize)
+size_t HUF_readCTable (HUF_CElt* CTable, unsigned* maxSymbolValuePtr, const void* src, size_t srcSize, unsigned* hasZeroWeights)
 {
     BYTE huffWeight[HUF_SYMBOLVALUE_MAX + 1];   /* init not required, even though some static analyzer may complain */
     U32 rankVal[HUF_TABLELOG_ABSOLUTEMAX + 1];   /* large enough for values from 0 to 16 */
@@ -192,9 +192,11 @@ size_t HUF_readCTable (HUF_CElt* CTable, unsigned* maxSymbolValuePtr, const void
     }   }
 
     /* fill nbBits */
+    *hasZeroWeights = 0;
     {   U32 n; for (n=0; n<nbSymbols; n++) {
             const U32 w = huffWeight[n];
-            CTable[n].nbBits = (BYTE)(tableLog + 1 - w);
+            *hasZeroWeights |= (w == 0);
+            CTable[n].nbBits = (BYTE)(tableLog + 1 - w) & -(w != 0);
     }   }
 
     /* fill val */
index 5dacba8cf129757fc3964b3954f837fe63278972..682c9c047d03e2495f4ac7be8059fa0643a4066c 100644 (file)
@@ -2853,14 +2853,23 @@ static size_t ZSTD_checkDictNCount(short* normalizedCounter, unsigned dictMaxSym
 
 size_t ZSTD_loadCEntropy(ZSTD_compressedBlockState_t* bs, void* workspace,
                          short* offcodeNCount, unsigned* offcodeMaxValue,
-                         const void* const dict, size_t dictSize) 
+                         const void* const dict, size_t dictSize)
 {
     const BYTE* dictPtr = (const BYTE*)dict;    /* skip magic num and dict ID */
     const BYTE* const dictEnd = dictPtr + dictSize;
     dictPtr += 8;
 
     {   unsigned maxSymbolValue = 255;
-        size_t const hufHeaderSize = HUF_readCTable((HUF_CElt*)bs->entropy.huf.CTable, &maxSymbolValue, dictPtr, dictEnd-dictPtr);
+        unsigned hasZeroWeights;
+        size_t const hufHeaderSize = HUF_readCTable((HUF_CElt*)bs->entropy.huf.CTable, &maxSymbolValue, dictPtr,
+            dictEnd-dictPtr, &hasZeroWeights);
+
+        /* We only set the loaded table as valid if it contains all non-zero
+         * weights. Otherwise, we set it to check */
+        if (!hasZeroWeights)
+            bs->entropy.huf.repeatMode = HUF_repeat_valid;
+        else bs->entropy.huf.repeatMode = HUF_repeat_check;
+
         RETURN_ERROR_IF(HUF_isError(hufHeaderSize), dictionary_corrupted);
         RETURN_ERROR_IF(maxSymbolValue < 255, dictionary_corrupted);
         dictPtr += hufHeaderSize;
@@ -2967,7 +2976,6 @@ static size_t ZSTD_loadZstdDictionary(ZSTD_compressedBlockState_t* bs,
                 RETURN_ERROR_IF(bs->rep[u] > dictContentSize, dictionary_corrupted);
         }   }
 
-        bs->entropy.huf.repeatMode = HUF_repeat_valid;
         bs->entropy.fse.offcode_repeatMode = FSE_repeat_valid;
         bs->entropy.fse.matchlength_repeatMode = FSE_repeat_valid;
         bs->entropy.fse.litlength_repeatMode = FSE_repeat_valid;
diff --git a/tests/dict-files/zero-weight-dict b/tests/dict-files/zero-weight-dict
new file mode 100644 (file)
index 0000000..c404120
Binary files /dev/null and b/tests/dict-files/zero-weight-dict differ
index 2955251143c35f6fdb0f19b514a4c4c97ddc87f5..df9568eb50259efbb9082050b6353429b01dc89b 100755 (executable)
@@ -502,6 +502,22 @@ cmp tmp tmp_decompress || die "difference between original and decompressed file
 println "test : incorrect stream size"
 cat tmp | $ZSTD -14 -f -o tmp.zst --stream-size=11001 && die "should fail with incorrect stream size"
 
+println "\n===>  zstd zero weight dict test "
+rm -f tmp*
+cp "$TESTDIR/dict-files/zero-weight-dict" tmp_input
+$ZSTD -D "$TESTDIR/dict-files/zero-weight-dict" tmp_input
+$ZSTD -D "$TESTDIR/dict-files/zero-weight-dict" -d tmp_input.zst -o tmp_decomp
+$DIFF tmp_decomp tmp_input
+rm -rf tmp*
+
+println "\n===>  zstd (valid) zero weight dict test "
+rm -f tmp*
+# 0 has a non-zero weight in the dictionary
+echo "0000000000000000000000000" > tmp_input
+$ZSTD -D "$TESTDIR/dict-files/zero-weight-dict" tmp_input
+$ZSTD -D "$TESTDIR/dict-files/zero-weight-dict" -d tmp_input.zst -o tmp_decomp
+$DIFF tmp_decomp tmp_input
+rm -rf tmp*
 
 println "\n===>  size-hint mode"
 
@@ -1189,7 +1205,6 @@ $ZSTD --train-cover "$TESTDIR"/*.c "$PRGDIR"/*.c
 test -f dictionary
 rm -f tmp* dictionary
 
-
 if [ "$isWindows" = false ] ; then
 
 println "\n===>  zstd fifo named pipe test "