]> git.ipfire.org Git - thirdparty/zstd.git/commitdiff
[lib] Allow ZSTD_CCtx_loadDictionary() to be called before parameters are set 1555/head
authorNick Terrell <terrelln@fb.com>
Thu, 21 Mar 2019 22:17:41 +0000 (15:17 -0700)
committerNick Terrell <terrelln@fb.com>
Thu, 21 Mar 2019 23:13:53 +0000 (16:13 -0700)
* After loading a dictionary only create the cdict once we've started the
  compression job. This allows the user to pass the dictionary before they
  set other settings, and is in line with the rest of the API.
* Add tests that mix the 3 dictionary loading APIs.
* Add extra tests for `ZSTD_CCtx_loadDictionary()`.
* The first 2 tests added fail before this patch.
* Run the regression test suite.

lib/compress/zstd_compress.c
lib/compress/zstd_compress_internal.h
tests/fuzzer.c

index 62dab80033a4266a428d9e412ef2e450f0d77ca0..fe59f522fd264fffb37a32176ef06fb8749cd63c 100644 (file)
@@ -103,12 +103,31 @@ ZSTD_CCtx* ZSTD_initStaticCCtx(void *workspace, size_t workspaceSize)
     return cctx;
 }
 
+/**
+ * Clears and frees all of the dictionaries in the CCtx.
+ */
+static void ZSTD_clearAllDicts(ZSTD_CCtx* cctx)
+{
+    ZSTD_free(cctx->localDict.dictBuffer, cctx->customMem);
+    ZSTD_freeCDict(cctx->localDict.cdict);
+    memset(&cctx->localDict, 0, sizeof(cctx->localDict));
+    memset(&cctx->prefixDict, 0, sizeof(cctx->prefixDict));
+    cctx->cdict = NULL;
+}
+
+static size_t ZSTD_sizeof_localDict(ZSTD_localDict dict)
+{
+    size_t const bufferSize = dict.dictBuffer != NULL ? dict.dictSize : 0;
+    size_t const cdictSize = ZSTD_sizeof_CDict(dict.cdict);
+    return bufferSize + cdictSize;
+}
+
 static void ZSTD_freeCCtxContent(ZSTD_CCtx* cctx)
 {
     assert(cctx != NULL);
     assert(cctx->staticSize == 0);
     ZSTD_free(cctx->workSpace, cctx->customMem); cctx->workSpace = NULL;
-    ZSTD_freeCDict(cctx->cdictLocal); cctx->cdictLocal = NULL;
+    ZSTD_clearAllDicts(cctx);
 #ifdef ZSTD_MULTITHREAD
     ZSTDMT_freeCCtx(cctx->mtctx); cctx->mtctx = NULL;
 #endif
@@ -140,7 +159,7 @@ size_t ZSTD_sizeof_CCtx(const ZSTD_CCtx* cctx)
 {
     if (cctx==NULL) return 0;   /* support sizeof on NULL */
     return sizeof(*cctx) + cctx->workSpaceSize
-           + ZSTD_sizeof_CDict(cctx->cdictLocal)
+           + ZSTD_sizeof_localDict(cctx->localDict)
            + ZSTD_sizeof_mtctx(cctx);
 }
 
@@ -785,6 +804,44 @@ ZSTDLIB_API size_t ZSTD_CCtx_setPledgedSrcSize(ZSTD_CCtx* cctx, unsigned long lo
     return 0;
 }
 
+/**
+ * Initializes the local dict using the requested parameters.
+ * NOTE: This does not use the pledged src size, because it may be used for more
+ * than one compression.
+ */
+static size_t ZSTD_initLocalDict(ZSTD_CCtx* cctx)
+{
+    ZSTD_localDict* const dl = &cctx->localDict;
+    ZSTD_compressionParameters const cParams = ZSTD_getCParamsFromCCtxParams(
+            &cctx->requestedParams, 0, dl->dictSize);
+    if (dl->dict == NULL) {
+        /* No local dictionary. */
+        assert(dl->dictBuffer == NULL);
+        assert(dl->cdict == NULL);
+        assert(dl->dictSize == 0);
+        return 0;
+    }
+    if (dl->cdict != NULL) {
+        assert(cctx->cdict == dl->cdict);
+        /* Local dictionary already initialized. */
+        return 0;
+    }
+    assert(dl->dictSize > 0);
+    assert(cctx->cdict == NULL);
+    assert(cctx->prefixDict.dict == NULL);
+
+    dl->cdict = ZSTD_createCDict_advanced(
+            dl->dict,
+            dl->dictSize,
+            ZSTD_dlm_byRef,
+            dl->dictContentType,
+            cParams,
+            cctx->customMem);
+    RETURN_ERROR_IF(!dl->cdict, memory_allocation);
+    cctx->cdict = dl->cdict;
+    return 0;
+}
+
 size_t ZSTD_CCtx_loadDictionary_advanced(
         ZSTD_CCtx* cctx, const void* dict, size_t dictSize,
         ZSTD_dictLoadMethod_e dictLoadMethod, ZSTD_dictContentType_e dictContentType)
@@ -793,20 +850,20 @@ size_t ZSTD_CCtx_loadDictionary_advanced(
     RETURN_ERROR_IF(cctx->staticSize, memory_allocation,
                     "no malloc for static CCtx");
     DEBUGLOG(4, "ZSTD_CCtx_loadDictionary_advanced (size: %u)", (U32)dictSize);
-    ZSTD_freeCDict(cctx->cdictLocal);  /* in case one already exists */
-    if (dict==NULL || dictSize==0) {   /* no dictionary mode */
-        cctx->cdictLocal = NULL;
-        cctx->cdict = NULL;
+    ZSTD_clearAllDicts(cctx);  /* in case one already exists */
+    if (dict == NULL || dictSize == 0)  /* no dictionary mode */
+        return 0;
+    if (dictLoadMethod == ZSTD_dlm_byRef) {
+        cctx->localDict.dict = dict;
     } else {
-        ZSTD_compressionParameters const cParams =
-                ZSTD_getCParamsFromCCtxParams(&cctx->requestedParams, cctx->pledgedSrcSizePlusOne-1, dictSize);
-        cctx->cdictLocal = ZSTD_createCDict_advanced(
-                                dict, dictSize,
-                                dictLoadMethod, dictContentType,
-                                cParams, cctx->customMem);
-        cctx->cdict = cctx->cdictLocal;
-        RETURN_ERROR_IF(cctx->cdictLocal == NULL, memory_allocation);
+        void* dictBuffer = ZSTD_malloc(dictSize, cctx->customMem);
+        RETURN_ERROR_IF(!dictBuffer, memory_allocation);
+        memcpy(dictBuffer, dict, dictSize);
+        cctx->localDict.dictBuffer = dictBuffer;
+        cctx->localDict.dict = dictBuffer;
     }
+    cctx->localDict.dictSize = dictSize;
+    cctx->localDict.dictContentType = dictContentType;
     return 0;
 }
 
@@ -828,10 +885,8 @@ size_t ZSTD_CCtx_refCDict(ZSTD_CCtx* cctx, const ZSTD_CDict* cdict)
 {
     RETURN_ERROR_IF(cctx->streamStage != zcss_init, stage_wrong);
     /* Free the existing local cdict (if any) to save memory. */
-    FORWARD_IF_ERROR( ZSTD_freeCDict(cctx->cdictLocal) );
-    cctx->cdictLocal = NULL;
+    ZSTD_clearAllDicts(cctx);
     cctx->cdict = cdict;
-    memset(&cctx->prefixDict, 0, sizeof(cctx->prefixDict));  /* exclusive */
     return 0;
 }
 
@@ -844,7 +899,7 @@ size_t ZSTD_CCtx_refPrefix_advanced(
         ZSTD_CCtx* cctx, const void* prefix, size_t prefixSize, ZSTD_dictContentType_e dictContentType)
 {
     RETURN_ERROR_IF(cctx->streamStage != zcss_init, stage_wrong);
-    cctx->cdict = NULL;   /* prefix discards any prior cdict */
+    ZSTD_clearAllDicts(cctx);
     cctx->prefixDict.dict = prefix;
     cctx->prefixDict.dictSize = prefixSize;
     cctx->prefixDict.dictContentType = dictContentType;
@@ -863,7 +918,7 @@ size_t ZSTD_CCtx_reset(ZSTD_CCtx* cctx, ZSTD_ResetDirective reset)
     if ( (reset == ZSTD_reset_parameters)
       || (reset == ZSTD_reset_session_and_parameters) ) {
         RETURN_ERROR_IF(cctx->streamStage != zcss_init, stage_wrong);
-        cctx->cdict = NULL;
+        ZSTD_clearAllDicts(cctx);
         return ZSTD_CCtxParams_reset(&cctx->requestedParams);
     }
     return 0;
@@ -4079,6 +4134,7 @@ size_t ZSTD_compressStream2( ZSTD_CCtx* cctx,
     if (cctx->streamStage == zcss_init) {
         ZSTD_CCtx_params params = cctx->requestedParams;
         ZSTD_prefixDict const prefixDict = cctx->prefixDict;
+        FORWARD_IF_ERROR( ZSTD_initLocalDict(cctx) ); /* Init the local dict if present. */
         memset(&cctx->prefixDict, 0, sizeof(cctx->prefixDict));   /* single usage */
         assert(prefixDict.dict==NULL || cctx->cdict==NULL);    /* only one can be set */
         DEBUGLOG(4, "ZSTD_compressStream2 : transparent init stage");
index c90034d4f52f7bbf014c46f021d2c49227bbce70..78b53550b0ca81f976c94b256ef1ceb54e46ad64 100644 (file)
@@ -54,6 +54,14 @@ typedef struct ZSTD_prefixDict_s {
     ZSTD_dictContentType_e dictContentType;
 } ZSTD_prefixDict;
 
+typedef struct {
+    void* dictBuffer;
+    void const* dict;
+    size_t dictSize;
+    ZSTD_dictContentType_e dictContentType;
+    ZSTD_CDict* cdict;
+} ZSTD_localDict;
+
 typedef struct {
     U32 CTable[HUF_CTABLE_SIZE_U32(255)];
     HUF_repeat repeatMode;
@@ -245,7 +253,7 @@ struct ZSTD_CCtx_s {
     U32    frameEnded;
 
     /* Dictionary */
-    ZSTD_CDict* cdictLocal;
+    ZSTD_localDict localDict;
     const ZSTD_CDict* cdict;
     ZSTD_prefixDict prefixDict;   /* single-usage dictionary */
 
index f74ca5d7b3568f07b804cc2f0e86f9ad33525313..7bc2f10cbb2dcfe238a520abc7ef5f4f56c8b612 100644 (file)
@@ -1285,9 +1285,13 @@ static int basicUnitTests(U32 seed, double compressibility)
         {
             size_t ret;
             MEM_writeLE32((char*)dictBuffer+2, ZSTD_MAGIC_DICTIONARY);
+            /* Either operation is allowed to fail, but one must fail. */
             ret = ZSTD_CCtx_loadDictionary_advanced(
                     cctx, (const char*)dictBuffer+2, dictSize-2, ZSTD_dlm_byRef, ZSTD_dct_auto);
-            if (!ZSTD_isError(ret)) goto _output_error;
+            if (!ZSTD_isError(ret)) {
+                ret = ZSTD_compress2(cctx, compressedBuffer, compressedBufferSize, CNBuffer, MIN(CNBuffSize, 100));
+                if (!ZSTD_isError(ret)) goto _output_error;
+            }
         }
         DISPLAYLEVEL(3, "OK \n");
 
@@ -1298,6 +1302,8 @@ static int basicUnitTests(U32 seed, double compressibility)
             ret = ZSTD_CCtx_loadDictionary_advanced(
                     cctx, (const char*)dictBuffer+2, dictSize-2, ZSTD_dlm_byRef, ZSTD_dct_rawContent);
             if (ZSTD_isError(ret)) goto _output_error;
+            ret = ZSTD_compress2(cctx, compressedBuffer, compressedBufferSize, CNBuffer, MIN(CNBuffSize, 100));
+            if (ZSTD_isError(ret)) goto _output_error;
         }
         DISPLAYLEVEL(3, "OK \n");
 
@@ -1312,6 +1318,113 @@ static int basicUnitTests(U32 seed, double compressibility)
         }
         DISPLAYLEVEL(3, "OK \n");
 
+        DISPLAYLEVEL(3, "test%3i : Loading dictionary before setting parameters is the same as loading after : ", testNb++);
+        {
+            size_t size1, size2;
+            ZSTD_CCtx_reset(cctx, ZSTD_reset_session_and_parameters);
+            CHECK_Z( ZSTD_CCtx_setParameter(cctx, ZSTD_c_compressionLevel, 7) );
+            CHECK_Z( ZSTD_CCtx_loadDictionary(cctx, CNBuffer, MIN(CNBuffSize, 10 KB)) );
+            size1 = ZSTD_compress2(cctx, compressedBuffer, compressedBufferSize, CNBuffer, MIN(CNBuffSize, 100 KB));
+            if (ZSTD_isError(size1)) goto _output_error;
+
+            ZSTD_CCtx_reset(cctx, ZSTD_reset_session_and_parameters);
+            CHECK_Z( ZSTD_CCtx_loadDictionary(cctx, CNBuffer, MIN(CNBuffSize, 10 KB)) );
+            CHECK_Z( ZSTD_CCtx_setParameter(cctx, ZSTD_c_compressionLevel, 7) );
+            size2 = ZSTD_compress2(cctx, compressedBuffer, compressedBufferSize, CNBuffer, MIN(CNBuffSize, 100 KB));
+            if (ZSTD_isError(size2)) goto _output_error;
+
+            if (size1 != size2) goto _output_error;
+        }
+        DISPLAYLEVEL(3, "OK \n");
+
+        DISPLAYLEVEL(3, "test%3i : Loading a dictionary clears the prefix : ", testNb++);
+        {
+            CHECK_Z( ZSTD_CCtx_refPrefix(cctx, (const char*)dictBuffer, dictSize) );
+            CHECK_Z( ZSTD_CCtx_loadDictionary(cctx, (const char*)dictBuffer, dictSize) );
+            CHECK_Z( ZSTD_compress2(cctx, compressedBuffer, compressedBufferSize, CNBuffer, MIN(CNBuffSize, 100)) );
+        }
+        DISPLAYLEVEL(3, "OK \n");
+
+        DISPLAYLEVEL(3, "test%3i : Loading a dictionary clears the cdict : ", testNb++);
+        {
+            ZSTD_CDict* const cdict = ZSTD_createCDict(dictBuffer, dictSize, 1);
+            CHECK_Z( ZSTD_CCtx_refCDict(cctx, cdict) );
+            CHECK_Z( ZSTD_CCtx_loadDictionary(cctx, (const char*)dictBuffer, dictSize) );
+            CHECK_Z( ZSTD_compress2(cctx, compressedBuffer, compressedBufferSize, CNBuffer, MIN(CNBuffSize, 100)) );
+            ZSTD_freeCDict(cdict);
+        }
+        DISPLAYLEVEL(3, "OK \n");
+
+        DISPLAYLEVEL(3, "test%3i : Loading a cdict clears the prefix : ", testNb++);
+        {
+            ZSTD_CDict* const cdict = ZSTD_createCDict(dictBuffer, dictSize, 1);
+            CHECK_Z( ZSTD_CCtx_refPrefix(cctx, (const char*)dictBuffer, dictSize) );
+            CHECK_Z( ZSTD_CCtx_refCDict(cctx, cdict) );
+            CHECK_Z( ZSTD_compress2(cctx, compressedBuffer, compressedBufferSize, CNBuffer, MIN(CNBuffSize, 100)) );
+            ZSTD_freeCDict(cdict);
+        }
+        DISPLAYLEVEL(3, "OK \n");
+
+        DISPLAYLEVEL(3, "test%3i : Loading a cdict clears the dictionary : ", testNb++);
+        {
+            ZSTD_CDict* const cdict = ZSTD_createCDict(dictBuffer, dictSize, 1);
+            CHECK_Z( ZSTD_CCtx_loadDictionary(cctx, (const char*)dictBuffer, dictSize) );
+            CHECK_Z( ZSTD_CCtx_refCDict(cctx, cdict) );
+            CHECK_Z( ZSTD_compress2(cctx, compressedBuffer, compressedBufferSize, CNBuffer, MIN(CNBuffSize, 100)) );
+            ZSTD_freeCDict(cdict);
+        }
+        DISPLAYLEVEL(3, "OK \n");
+
+        DISPLAYLEVEL(3, "test%3i : Loading a prefix clears the dictionary : ", testNb++);
+        {
+            CHECK_Z( ZSTD_CCtx_loadDictionary(cctx, (const char*)dictBuffer, dictSize) );
+            CHECK_Z( ZSTD_CCtx_refPrefix(cctx, (const char*)dictBuffer, dictSize) );
+            CHECK_Z( ZSTD_compress2(cctx, compressedBuffer, compressedBufferSize, CNBuffer, MIN(CNBuffSize, 100)) );
+        }
+        DISPLAYLEVEL(3, "OK \n");
+
+        DISPLAYLEVEL(3, "test%3i : Loading a prefix clears the cdict : ", testNb++);
+        {
+            ZSTD_CDict* const cdict = ZSTD_createCDict(dictBuffer, dictSize, 1);
+            CHECK_Z( ZSTD_CCtx_refCDict(cctx, cdict) );
+            CHECK_Z( ZSTD_CCtx_refPrefix(cctx, (const char*)dictBuffer, dictSize) );
+            CHECK_Z( ZSTD_compress2(cctx, compressedBuffer, compressedBufferSize, CNBuffer, MIN(CNBuffSize, 100)) );
+            ZSTD_freeCDict(cdict);
+        }
+        DISPLAYLEVEL(3, "OK \n");
+
+        DISPLAYLEVEL(3, "test%3i : Loaded dictionary persists across reset session : ", testNb++);
+        {
+            size_t size1, size2;
+            ZSTD_CCtx_reset(cctx, ZSTD_reset_session_and_parameters);
+            CHECK_Z( ZSTD_CCtx_loadDictionary(cctx, CNBuffer, MIN(CNBuffSize, 10 KB)) );
+            size1 = ZSTD_compress2(cctx, compressedBuffer, compressedBufferSize, CNBuffer, MIN(CNBuffSize, 100 KB));
+            if (ZSTD_isError(size1)) goto _output_error;
+
+            ZSTD_CCtx_reset(cctx, ZSTD_reset_session_only);
+            size2 = ZSTD_compress2(cctx, compressedBuffer, compressedBufferSize, CNBuffer, MIN(CNBuffSize, 100 KB));
+            if (ZSTD_isError(size2)) goto _output_error;
+
+            if (size1 != size2) goto _output_error;
+        }
+        DISPLAYLEVEL(3, "OK \n");
+
+        DISPLAYLEVEL(3, "test%3i : Loaded dictionary is cleared after resetting parameters : ", testNb++);
+        {
+            size_t size1, size2;
+            ZSTD_CCtx_reset(cctx, ZSTD_reset_session_and_parameters);
+            CHECK_Z( ZSTD_CCtx_loadDictionary(cctx, CNBuffer, MIN(CNBuffSize, 10 KB)) );
+            size1 = ZSTD_compress2(cctx, compressedBuffer, compressedBufferSize, CNBuffer, MIN(CNBuffSize, 100 KB));
+            if (ZSTD_isError(size1)) goto _output_error;
+
+            ZSTD_CCtx_reset(cctx, ZSTD_reset_session_and_parameters);
+            size2 = ZSTD_compress2(cctx, compressedBuffer, compressedBufferSize, CNBuffer, MIN(CNBuffSize, 100 KB));
+            if (ZSTD_isError(size2)) goto _output_error;
+
+            if (size1 == size2) goto _output_error;
+        }
+        DISPLAYLEVEL(3, "OK \n");
+
         DISPLAYLEVEL(3, "test%3i : Dictionary with non-default repcodes : ", testNb++);
         { U32 u; for (u=0; u<nbSamples; u++) samplesSizes[u] = sampleUnitSize; }
         dictSize = ZDICT_trainFromBuffer(dictBuffer, dictSize,