]> git.ipfire.org Git - thirdparty/zstd.git/commitdiff
[lib] Always load the dictionary in one go
authorNick Terrell <terrelln@fb.com>
Tue, 4 May 2021 05:33:22 +0000 (22:33 -0700)
committerNick Terrell <terrelln@fb.com>
Tue, 4 May 2021 23:45:25 +0000 (16:45 -0700)
Dictionaries larger than `ZSTD_CHUNKSIZE_MAX` used to have to be loaded
in multiple segments. Instead, when we detect large dictionaries, ensure
that we reset the context's indicies. Then, for dictionaries larger than
`ZSTD_CURRENT_MAX - 1`, only load the suffix of the dictionary. Finally,
enable DDS for large dictionaries, since we no longer load in multiple
segments.

This simplifes the dictionary loading code, and reduces opportunities
for non-determinism to slip in.

lib/compress/zstd_compress.c
lib/compress/zstd_compress_internal.h

index 197e8ed8ac8c7145f94b3f44dc4224d109fd3715..14b284af0d4d1f821eb7503dc1e61957169f1957 100644 (file)
@@ -1785,11 +1785,26 @@ static int ZSTD_indexTooCloseToMax(ZSTD_window_t w)
     return (size_t)(w.nextSrc - w.base) > (ZSTD_CURRENT_MAX - ZSTD_INDEXOVERFLOW_MARGIN);
 }
 
+/** ZSTD_dictTooBig():
+ * When dictionaries are larger than ZSTD_CHUNKSIZE_MAX they can't be loaded in
+ * one go generically. So we ensure that in that case we reset the tables to zero,
+ * so that we can load as much of the dictionary as possible.
+ */
+static int ZSTD_dictTooBig(size_t const loadedDictSize)
+{
+    return loadedDictSize > ZSTD_CHUNKSIZE_MAX;
+}
+
 /*! ZSTD_resetCCtx_internal() :
-    note : `params` are assumed fully validated at this stage */
+ * @param loadedDictSize The size of the dictionary to be loaded
+ * into the context, if any. If no dictionary is used, or the
+ * dictionary is being attached / copied, then pass 0.
+ * note : `params` are assumed fully validated at this stage.
+ */
 static size_t ZSTD_resetCCtx_internal(ZSTD_CCtx* zc,
                                       ZSTD_CCtx_params params,
                                       U64 const pledgedSrcSize,
+                                      size_t const loadedDictSize,
                                       ZSTD_compResetPolicy_e const crp,
                                       ZSTD_buffered_policy_e const zbuff)
 {
@@ -1821,8 +1836,9 @@ static size_t ZSTD_resetCCtx_internal(ZSTD_CCtx* zc,
         size_t const maxNbLdmSeq = ZSTD_ldm_getMaxNbSeq(params.ldmParams, blockSize);
 
         int const indexTooClose = ZSTD_indexTooCloseToMax(zc->blockState.matchState.window);
+        int const dictTooBig = ZSTD_dictTooBig(loadedDictSize);
         ZSTD_indexResetPolicy_e needsIndexReset =
-            (!indexTooClose && zc->initialized) ? ZSTDirp_continue : ZSTDirp_reset;
+            (indexTooClose || dictTooBig || !zc->initialized) ? ZSTDirp_reset : ZSTDirp_continue;
 
         size_t const neededSpace =
             ZSTD_estimateCCtxSize_usingCCtxParams_internal(
@@ -1937,7 +1953,6 @@ static size_t ZSTD_resetCCtx_internal(ZSTD_CCtx* zc,
             zc->maxNbLdmSequences = maxNbLdmSeq;
 
             ZSTD_window_init(&zc->ldmState.window);
-            ZSTD_window_clear(&zc->ldmState.window);
             zc->ldmState.loadedDictEnd = 0;
         }
 
@@ -2017,6 +2032,7 @@ ZSTD_resetCCtx_byAttachingCDict(ZSTD_CCtx* cctx,
         params.cParams.windowLog = windowLog;
         params.useRowMatchFinder = cdict->useRowMatchFinder;    /* cdict overrides */
         FORWARD_IF_ERROR(ZSTD_resetCCtx_internal(cctx, params, pledgedSrcSize,
+                                                 /* loadedDictSize */ 0,
                                                  ZSTDcrp_makeClean, zbuff), "");
         assert(cctx->appliedParams.cParams.strategy == adjusted_cdict_cParams.strategy);
     }
@@ -2069,6 +2085,7 @@ static size_t ZSTD_resetCCtx_byCopyingCDict(ZSTD_CCtx* cctx,
         params.cParams.windowLog = windowLog;
         params.useRowMatchFinder = cdict->useRowMatchFinder;
         FORWARD_IF_ERROR(ZSTD_resetCCtx_internal(cctx, params, pledgedSrcSize,
+                                                 /* loadedDictSize */ 0,
                                                  ZSTDcrp_leaveDirty, zbuff), "");
         assert(cctx->appliedParams.cParams.strategy == cdict_cParams->strategy);
         assert(cctx->appliedParams.cParams.hashLog == cdict_cParams->hashLog);
@@ -2174,6 +2191,7 @@ static size_t ZSTD_copyCCtx_internal(ZSTD_CCtx* dstCCtx,
         params.useRowMatchFinder = srcCCtx->appliedParams.useRowMatchFinder;
         params.fParams = fParams;
         ZSTD_resetCCtx_internal(dstCCtx, params, pledgedSrcSize,
+                                /* loadedDictSize */ 0,
                                 ZSTDcrp_leaveDirty, zbuff);
         assert(dstCCtx->appliedParams.cParams.windowLog == srcCCtx->appliedParams.cParams.windowLog);
         assert(dstCCtx->appliedParams.cParams.strategy == srcCCtx->appliedParams.cParams.strategy);
@@ -4099,77 +4117,86 @@ static size_t ZSTD_loadDictionaryContent(ZSTD_matchState_t* ms,
 {
     const BYTE* ip = (const BYTE*) src;
     const BYTE* const iend = ip + srcSize;
+    int const loadLdmDict = params->ldmParams.enableLdm && ls != NULL;
+
+    /* Assert that we the ms params match the params we're being given */
+    ZSTD_assertEqualCParams(params->cParams, ms->cParams);
+    if (srcSize > ZSTD_CHUNKSIZE_MAX) {
+        /* Allow the dictionary to set indices up to exactly ZSTD_CURRENT_MAX.
+         * Dictionaries right at the edge will immediately trigger overflow
+         * correction, but I don't want to insert extra constraints here.
+         */
+        U32 const maxDictSize = ZSTD_CURRENT_MAX - 1;
+        /* We must have cleared our windows when our source is this large. */
+        assert(ZSTD_window_isEmpty(ms->window));
+        if (loadLdmDict)
+            assert(ZSTD_window_isEmpty(ls->window));
+        /* If the dictionary is too large, only load the suffix of the dictionary. */
+        if (srcSize > maxDictSize) {
+            ip = iend - maxDictSize;
+            src = ip;
+            srcSize = maxDictSize;
+        }
+    }
 
     DEBUGLOG(4, "ZSTD_loadDictionaryContent(): useRowMatchFinder=%d", (int)params->useRowMatchFinder);
     ZSTD_window_update(&ms->window, src, srcSize);
     ms->loadedDictEnd = params->forceWindow ? 0 : (U32)(iend - ms->window.base);
 
-    if (params->ldmParams.enableLdm && ls != NULL) {
+    if (loadLdmDict) {
         ZSTD_window_update(&ls->window, src, srcSize);
         ls->loadedDictEnd = params->forceWindow ? 0 : (U32)(iend - ls->window.base);
     }
 
-    /* Assert that we the ms params match the params we're being given */
-    ZSTD_assertEqualCParams(params->cParams, ms->cParams);
-
     if (srcSize <= HASH_READ_SIZE) return 0;
 
-    while (iend - ip > HASH_READ_SIZE) {
-        size_t const remaining = (size_t)(iend - ip);
-        size_t const chunk = MIN(remaining, ZSTD_CHUNKSIZE_MAX);
-        const BYTE* const ichunk = ip + chunk;
-
-        ZSTD_overflowCorrectIfNeeded(ms, ws, params, ip, ichunk);
+    ZSTD_overflowCorrectIfNeeded(ms, ws, params, ip, iend);
 
-        if (params->ldmParams.enableLdm && ls != NULL)
-            ZSTD_ldm_fillHashTable(ls, (const BYTE*)src, (const BYTE*)src + srcSize, &params->ldmParams);
+    if (loadLdmDict)
+        ZSTD_ldm_fillHashTable(ls, (const BYTE*)src, (const BYTE*)src + srcSize, &params->ldmParams);
 
-        switch(params->cParams.strategy)
-        {
-        case ZSTD_fast:
-            ZSTD_fillHashTable(ms, ichunk, dtlm);
-            break;
-        case ZSTD_dfast:
-            ZSTD_fillDoubleHashTable(ms, ichunk, dtlm);
-            break;
+    switch(params->cParams.strategy)
+    {
+    case ZSTD_fast:
+        ZSTD_fillHashTable(ms, iend, dtlm);
+        break;
+    case ZSTD_dfast:
+        ZSTD_fillDoubleHashTable(ms, iend, dtlm);
+        break;
 
-        case ZSTD_greedy:
-        case ZSTD_lazy:
-        case ZSTD_lazy2:
-            if (chunk >= HASH_READ_SIZE) {
-                if (ms->dedicatedDictSearch) {
-                    assert(chunk == remaining); /* must load everything in one go */
-                    assert(ms->chainTable != NULL);
-                    ZSTD_dedicatedDictSearch_lazy_loadDictionary(ms, ichunk-HASH_READ_SIZE);
+    case ZSTD_greedy:
+    case ZSTD_lazy:
+    case ZSTD_lazy2:
+        if (srcSize >= HASH_READ_SIZE) {
+            if (ms->dedicatedDictSearch) {
+                assert(ms->chainTable != NULL);
+                ZSTD_dedicatedDictSearch_lazy_loadDictionary(ms, iend-HASH_READ_SIZE);
+            } else {
+                assert(params->useRowMatchFinder != ZSTD_urm_auto);
+                if (params->useRowMatchFinder == ZSTD_urm_enableRowMatchFinder) {
+                    size_t const tagTableSize = ((size_t)1 << params->cParams.hashLog) * sizeof(U16);
+                    if (ip == src)
+                        ZSTD_memset(ms->tagTable, 0, tagTableSize);
+                    ZSTD_row_update(ms, iend-HASH_READ_SIZE);
+                    DEBUGLOG(4, "Using row-based hash table for lazy dict");
                 } else {
-                    assert(params->useRowMatchFinder != ZSTD_urm_auto);
-                    if (params->useRowMatchFinder == ZSTD_urm_enableRowMatchFinder) {
-                        size_t const tagTableSize = ((size_t)1 << params->cParams.hashLog) * sizeof(U16);
-                        if (ip == src)
-                            ZSTD_memset(ms->tagTable, 0, tagTableSize);
-                        ZSTD_row_update(ms, ichunk-HASH_READ_SIZE);
-                        DEBUGLOG(4, "Using row-based hash table for lazy dict");
-                    } else {
-                        ZSTD_insertAndFindFirstIndex(ms, ichunk-HASH_READ_SIZE);
-                        DEBUGLOG(4, "Using chain-based hash table for lazy dict");
-                    }
+                    ZSTD_insertAndFindFirstIndex(ms, iend-HASH_READ_SIZE);
+                    DEBUGLOG(4, "Using chain-based hash table for lazy dict");
                 }
             }
-            break;
-
-        case ZSTD_btlazy2:   /* we want the dictionary table fully sorted */
-        case ZSTD_btopt:
-        case ZSTD_btultra:
-        case ZSTD_btultra2:
-            if (chunk >= HASH_READ_SIZE)
-                ZSTD_updateTree(ms, ichunk-HASH_READ_SIZE, ichunk);
-            break;
-
-        default:
-            assert(0);  /* not possible : not a valid strategy id */
         }
+        break;
+
+    case ZSTD_btlazy2:   /* we want the dictionary table fully sorted */
+    case ZSTD_btopt:
+    case ZSTD_btultra:
+    case ZSTD_btultra2:
+        if (srcSize >= HASH_READ_SIZE)
+            ZSTD_updateTree(ms, iend-HASH_READ_SIZE, iend);
+        break;
 
-        ip = ichunk;
+    default:
+        assert(0);  /* not possible : not a valid strategy id */
     }
 
     ms->nextToUpdate = (U32)(iend - ms->window.base);
@@ -4378,6 +4405,7 @@ static size_t ZSTD_compressBegin_internal(ZSTD_CCtx* cctx,
                                     const ZSTD_CCtx_params* params, U64 pledgedSrcSize,
                                     ZSTD_buffered_policy_e zbuff)
 {
+    size_t const dictContentSize = cdict ? cdict->dictContentSize : dictSize;
 #if ZSTD_TRACE
     cctx->traceCtx = (ZSTD_trace_compress_begin != NULL) ? ZSTD_trace_compress_begin(cctx) : 0;
 #endif
@@ -4396,6 +4424,7 @@ static size_t ZSTD_compressBegin_internal(ZSTD_CCtx* cctx,
     }
 
     FORWARD_IF_ERROR( ZSTD_resetCCtx_internal(cctx, *params, pledgedSrcSize,
+                                     dictContentSize,
                                      ZSTDcrp_makeClean, zbuff) , "");
     {   size_t const dictID = cdict ?
                 ZSTD_compress_insertDictionary(
@@ -4410,7 +4439,7 @@ static size_t ZSTD_compressBegin_internal(ZSTD_CCtx* cctx,
         FORWARD_IF_ERROR(dictID, "ZSTD_compress_insertDictionary failed");
         assert(dictID <= UINT_MAX);
         cctx->dictID = (U32)dictID;
-        cctx->dictContentSize = cdict ? cdict->dictContentSize : dictSize;
+        cctx->dictContentSize = dictContentSize;
     }
     return 0;
 }
@@ -4680,9 +4709,6 @@ static size_t ZSTD_initCDict_internal(
     assert(!ZSTD_checkCParams(params.cParams));
     cdict->matchState.cParams = params.cParams;
     cdict->matchState.dedicatedDictSearch = params.enableDedicatedDictSearch;
-    if (cdict->matchState.dedicatedDictSearch && dictSize > ZSTD_CHUNKSIZE_MAX) {
-        cdict->matchState.dedicatedDictSearch = 0;
-    }
     if ((dictLoadMethod == ZSTD_dlm_byRef) || (!dictBuffer) || (!dictSize)) {
         cdict->dictContent = dictBuffer;
     } else {
index 3e6c576d1eff7541153f4e0dc6c8b4e3d1eec62b..c0b01794876ade2d22da7a7a05d321ee5ea45539 100644 (file)
@@ -872,6 +872,13 @@ MEM_STATIC void ZSTD_window_clear(ZSTD_window_t* window)
     window->dictLimit = end;
 }
 
+MEM_STATIC U32 ZSTD_window_isEmpty(ZSTD_window_t const window)
+{
+    return window.dictLimit == 1 &&
+           window.lowLimit == 1 &&
+           (window.nextSrc - window.base) == 1;
+}
+
 /**
  * ZSTD_window_hasExtDict():
  * Returns non-zero if the window has a non-empty extDict.