]> git.ipfire.org Git - thirdparty/zstd.git/commitdiff
fixed ZSTD_loadZstdDictionary() 625/head
authorYann Collet <cyan@fb.com>
Fri, 24 Mar 2017 19:46:46 +0000 (12:46 -0700)
committerYann Collet <cyan@fb.com>
Fri, 24 Mar 2017 19:46:46 +0000 (12:46 -0700)
forgot to add the dictionary content
(tests were not failing, just compressing less).

Also : added size protections when adding dict content
since hc/bt table filling would fail if size < 8

lib/compress/zstd_compress.c

index eecc76950d765400b5efb1de633a7299bed358c4..c1ef526b4ec71303c8b7418cf3fba347c278dfc5 100644 (file)
@@ -2504,7 +2504,9 @@ size_t ZSTD_compressBlock(ZSTD_CCtx* cctx, void* dst, size_t dstCapacity, const
     return ZSTD_compressContinue_internal(cctx, dst, dstCapacity, src, srcSize, 0, 0);
 }
 
-
+/*! ZSTD_loadDictionaryContent() :
+ *  @return : 0, or an error code
+ */
 static size_t ZSTD_loadDictionaryContent(ZSTD_CCtx* zc, const void* src, size_t srcSize)
 {
     const BYTE* const ip = (const BYTE*) src;
@@ -2534,13 +2536,15 @@ static size_t ZSTD_loadDictionaryContent(ZSTD_CCtx* zc, const void* src, size_t
     case ZSTD_greedy:
     case ZSTD_lazy:
     case ZSTD_lazy2:
-        ZSTD_insertAndFindFirstIndex (zc, iend-HASH_READ_SIZE, zc->params.cParams.searchLength);
+        if (srcSize >= HASH_READ_SIZE)
+            ZSTD_insertAndFindFirstIndex(zc, iend-HASH_READ_SIZE, zc->params.cParams.searchLength);
         break;
 
     case ZSTD_btlazy2:
     case ZSTD_btopt:
     case ZSTD_btopt2:
-        ZSTD_updateTree(zc, iend-HASH_READ_SIZE, iend, 1 << zc->params.cParams.searchLog, zc->params.cParams.searchLength);
+        if (srcSize >= HASH_READ_SIZE)
+            ZSTD_updateTree(zc, iend-HASH_READ_SIZE, iend, 1 << zc->params.cParams.searchLog, zc->params.cParams.searchLength);
         break;
 
     default:
@@ -2570,12 +2574,12 @@ static size_t ZSTD_checkDictNCount(short* normalizedCounter, unsigned dictMaxSym
  * See :
  * https://github.com/facebook/zstd/blob/master/doc/zstd_compression_format.md#dictionary-format
  */
-/*! ZSTD_loadDictionary() :
- * @return : size read from dictionary
- *  note : magic number supposed already checked
- *         dictSize supposed > 8
+/*! ZSTD_loadZstdDictionary() :
+ * @return : 0, or an error code
+ *  assumptions : magic number supposed already checked
+ *                dictSize supposed > 8
  */
-static size_t ZSTD_loadDictionary(ZSTD_CCtx* cctx, const void* dict, size_t dictSize)
+static size_t ZSTD_loadZstdDictionary(ZSTD_CCtx* cctx, const void* dict, size_t dictSize)
 {
     const BYTE* dictPtr = (const BYTE*)dict;
     const BYTE* const dictEnd = dictPtr + dictSize;
@@ -2624,9 +2628,9 @@ static size_t ZSTD_loadDictionary(ZSTD_CCtx* cctx, const void* dict, size_t dict
     }
 
     if (dictPtr+12 > dictEnd) return ERROR(dictionary_corrupted);
-    cctx->rep[0] = MEM_readLE32(dictPtr+0); if (cctx->rep[0] == 0 || cctx->rep[0] >= dictSize) return ERROR(dictionary_corrupted);
-    cctx->rep[1] = MEM_readLE32(dictPtr+4); if (cctx->rep[1] == 0 || cctx->rep[1] >= dictSize) return ERROR(dictionary_corrupted);
-    cctx->rep[2] = MEM_readLE32(dictPtr+8); if (cctx->rep[2] == 0 || cctx->rep[2] >= dictSize) return ERROR(dictionary_corrupted);
+    cctx->rep[0] = MEM_readLE32(dictPtr+0);
+    cctx->rep[1] = MEM_readLE32(dictPtr+4);
+    cctx->rep[2] = MEM_readLE32(dictPtr+8);
     dictPtr += 12;
 
     {   size_t const dictContentSize = (size_t)(dictEnd - dictPtr);
@@ -2642,25 +2646,26 @@ static size_t ZSTD_loadDictionary(ZSTD_CCtx* cctx, const void* dict, size_t dict
             for (u=0; u<3; u++) {
                 if (cctx->rep[u] == 0) return ERROR(dictionary_corrupted);
                 if (cctx->rep[u] > dictContentSize) return ERROR(dictionary_corrupted);
-    }   }   }
+        }   }
 
-    cctx->flagStaticTables = 1;
-    cctx->flagStaticHufTable = HUF_repeat_valid;
-    return dictPtr - (const BYTE*)dict;
+        cctx->flagStaticTables = 1;
+        cctx->flagStaticHufTable = HUF_repeat_valid;
+        return ZSTD_loadDictionaryContent(cctx, dictPtr, dictContentSize);
+    }
 }
 
 /** ZSTD_compress_insertDictionary() :
 *   @return : 0, or an error code */
-static size_t ZSTD_compress_insertDictionary(ZSTD_CCtx* zc, const void* dict, size_t dictSize)
+static size_t ZSTD_compress_insertDictionary(ZSTD_CCtx* cctx, const void* dict, size_t dictSize)
 {
     if ((dict==NULL) || (dictSize<=8)) return 0;
 
     /* dict as pure content */
-    if ((MEM_readLE32(dict) != ZSTD_DICT_MAGIC) || (zc->forceRawDict))
-        return ZSTD_loadDictionaryContent(zc, dict, dictSize);
+    if ((MEM_readLE32(dict) != ZSTD_DICT_MAGIC) || (cctx->forceRawDict))
+        return ZSTD_loadDictionaryContent(cctx, dict, dictSize);
 
     /* dict as zstd dictionary */
-    return ZSTD_loadDictionary(zc, dict, dictSize);
+    return ZSTD_loadZstdDictionary(cctx, dict, dictSize);
 }
 
 /*! ZSTD_compressBegin_internal() :