]> git.ipfire.org Git - thirdparty/zstd.git/commitdiff
Fix & refactor Huffman repeat tables for dictionaries
authorNick Terrell <terrelln@meta.com>
Thu, 24 Aug 2023 21:41:21 +0000 (14:41 -0700)
committerNick Terrell <nickrterrell@gmail.com>
Fri, 25 Aug 2023 17:21:58 +0000 (13:21 -0400)
The Huffman repeat mode checker assumed that the CTable was zeroed in the region `[maxSymbolValue + 1, 256)`.
This assumption didn't hold for tables built in the dictionaries, because it didn't go through the same codepath.

Since this code was originally written, we added a header to the CTable that specifies the `tableLog`.
Add `maxSymbolValue` to that header, and check that the table's `maxSymbolValue` is at least the block's `maxSymbolValue`.

This solution is cleaner because we write this header for every CTable we build, so it can't be missed in any code path.

Credit to OSS-Fuzz

lib/common/huf.h
lib/compress/huf_compress.c

index 73d1ee565438eea5f6baef9c6abcdb82f327b5de..99bf85d6f4ed2a6ad02b911b5721eed979179147 100644 (file)
@@ -197,9 +197,22 @@ size_t HUF_readCTable (HUF_CElt* CTable, unsigned* maxSymbolValuePtr, const void
 
 /** HUF_getNbBitsFromCTable() :
  *  Read nbBits from CTable symbolTable, for symbol `symbolValue` presumed <= HUF_SYMBOLVALUE_MAX
- *  Note 1 : is not inlined, as HUF_CElt definition is private */
+ *  Note 1 : If symbolValue > HUF_readCTableHeader(symbolTable).maxSymbolValue, returns 0
+ *  Note 2 : is not inlined, as HUF_CElt definition is private
+ */
 U32 HUF_getNbBitsFromCTable(const HUF_CElt* symbolTable, U32 symbolValue);
 
+typedef struct {
+    BYTE tableLog;
+    BYTE maxSymbolValue;
+    BYTE unused[sizeof(size_t) - 2];
+} HUF_CTableHeader;
+
+/** HUF_readCTableHeader() :
+ *  @returns The header from the CTable specifying the tableLog and the maxSymbolValue.
+ */
+HUF_CTableHeader HUF_readCTableHeader(HUF_CElt const* ctable);
+
 /*
  * HUF_decompress() does the following:
  * 1. select the decompression algorithm (X1, X2) based on pre-computed heuristics
index 29871877a7f6b5d7150e57010f9d01655d65a554..3fe2578960339429cfd78da45075446d20b38fad 100644 (file)
@@ -220,6 +220,25 @@ static void HUF_setValue(HUF_CElt* elt, size_t value)
     }
 }
 
+HUF_CTableHeader HUF_readCTableHeader(HUF_CElt const* ctable)
+{
+    HUF_CTableHeader header;
+    ZSTD_memcpy(&header, ctable, sizeof(header));
+    return header;
+}
+
+static void HUF_writeCTableHeader(HUF_CElt* ctable, U32 tableLog, U32 maxSymbolValue)
+{
+    HUF_CTableHeader header;
+    HUF_STATIC_ASSERT(sizeof(ctable[0]) == sizeof(header));
+    ZSTD_memset(&header, 0, sizeof(header));
+    assert(tableLog < 256);
+    header.tableLog = (BYTE)tableLog;
+    assert(maxSymbolValue < 256);
+    header.maxSymbolValue = (BYTE)maxSymbolValue;
+    ZSTD_memcpy(ctable, &header, sizeof(header));
+}
+
 typedef struct {
     HUF_CompressWeightsWksp wksp;
     BYTE bitsToWeight[HUF_TABLELOG_MAX + 1];   /* precomputed conversion table */
@@ -237,6 +256,9 @@ size_t HUF_writeCTable_wksp(void* dst, size_t maxDstSize,
 
     HUF_STATIC_ASSERT(HUF_CTABLE_WORKSPACE_SIZE >= sizeof(HUF_WriteCTableWksp));
 
+    assert(HUF_readCTableHeader(CTable).maxSymbolValue == maxSymbolValue);
+    assert(HUF_readCTableHeader(CTable).tableLog == huffLog);
+
     /* check conditions */
     if (workspaceSize < sizeof(HUF_WriteCTableWksp)) return ERROR(GENERIC);
     if (maxSymbolValue > HUF_SYMBOLVALUE_MAX) return ERROR(maxSymbolValue_tooLarge);
@@ -283,7 +305,9 @@ size_t HUF_readCTable (HUF_CElt* CTable, unsigned* maxSymbolValuePtr, const void
     if (tableLog > HUF_TABLELOG_MAX) return ERROR(tableLog_tooLarge);
     if (nbSymbols > *maxSymbolValuePtr+1) return ERROR(maxSymbolValue_tooSmall);
 
-    CTable[0] = tableLog;
+    *maxSymbolValuePtr = nbSymbols - 1;
+
+    HUF_writeCTableHeader(CTable, tableLog, *maxSymbolValuePtr);
 
     /* Prepare base value per rank */
     {   U32 n, nextRankStart = 0;
@@ -315,7 +339,6 @@ size_t HUF_readCTable (HUF_CElt* CTable, unsigned* maxSymbolValuePtr, const void
         { U32 n; for (n=0; n<nbSymbols; n++) HUF_setValue(ct + n, valPerRank[HUF_getNbBits(ct[n])]++); }
     }
 
-    *maxSymbolValuePtr = nbSymbols - 1;
     return readSize;
 }
 
@@ -323,6 +346,8 @@ U32 HUF_getNbBitsFromCTable(HUF_CElt const* CTable, U32 symbolValue)
 {
     const HUF_CElt* const ct = CTable + 1;
     assert(symbolValue <= HUF_SYMBOLVALUE_MAX);
+    if (symbolValue > HUF_readCTableHeader(CTable).maxSymbolValue)
+        return 0;
     return (U32)HUF_getNbBits(ct[symbolValue]);
 }
 
@@ -723,7 +748,8 @@ static void HUF_buildCTableFromTree(HUF_CElt* CTable, nodeElt const* huffNode, i
         HUF_setNbBits(ct + huffNode[n].byte, huffNode[n].nbBits);   /* push nbBits per symbol, symbol order */
     for (n=0; n<alphabetSize; n++)
         HUF_setValue(ct + n, valPerRank[HUF_getNbBits(ct[n])]++);   /* assign value within rank, symbol order */
-    CTable[0] = maxNbBits;
+
+    HUF_writeCTableHeader(CTable, maxNbBits, maxSymbolValue);
 }
 
 size_t
@@ -776,13 +802,20 @@ size_t HUF_estimateCompressedSize(const HUF_CElt* CTable, const unsigned* count,
 }
 
 int HUF_validateCTable(const HUF_CElt* CTable, const unsigned* count, unsigned maxSymbolValue) {
-  HUF_CElt const* ct = CTable + 1;
-  int bad = 0;
-  int s;
-  for (s = 0; s <= (int)maxSymbolValue; ++s) {
-    bad |= (count[s] != 0) & (HUF_getNbBits(ct[s]) == 0);
-  }
-  return !bad;
+    HUF_CTableHeader header = HUF_readCTableHeader(CTable);
+    HUF_CElt const* ct = CTable + 1;
+    int bad = 0;
+    int s;
+
+    assert(header.tableLog <= HUF_TABLELOG_ABSOLUTEMAX);
+
+    if (header.maxSymbolValue < maxSymbolValue)
+        return 0;
+
+    for (s = 0; s <= (int)maxSymbolValue; ++s) {
+        bad |= (count[s] != 0) & (HUF_getNbBits(ct[s]) == 0);
+    }
+    return !bad;
 }
 
 size_t HUF_compressBound(size_t size) { return HUF_COMPRESSBOUND(size); }
@@ -1024,7 +1057,7 @@ HUF_compress1X_usingCTable_internal_body(void* dst, size_t dstSize,
                                    const void* src, size_t srcSize,
                                    const HUF_CElt* CTable)
 {
-    U32 const tableLog = (U32)CTable[0];
+    U32 const tableLog = HUF_readCTableHeader(CTable).tableLog;
     HUF_CElt const* ct = CTable + 1;
     const BYTE* ip = (const BYTE*) src;
     BYTE* const ostart = (BYTE*)dst;
@@ -1372,12 +1405,6 @@ HUF_compress_internal (void* dst, size_t dstSize,
         huffLog = (U32)maxBits;
         DEBUGLOG(6, "bit distribution completed (%zu symbols)", showCTableBits(table->CTable + 1, maxSymbolValue+1));
     }
-    /* Zero unused symbols in CTable, so we can check it for validity */
-    {
-        size_t const ctableSize = HUF_CTABLE_SIZE_ST(maxSymbolValue);
-        size_t const unusedSize = sizeof(table->CTable) - ctableSize * sizeof(HUF_CElt);
-        ZSTD_memset(table->CTable + ctableSize, 0, unusedSize);
-    }
 
     /* Write table description header */
     {   CHECK_V_F(hSize, HUF_writeCTable_wksp(op, dstSize, table->CTable, maxSymbolValue, huffLog,