]> git.ipfire.org Git - thirdparty/zstd.git/commitdiff
Add salt into row hash (#3528 part 2) (#3533)
authorYonatan Komornik <11005061+yoniko@users.noreply.github.com>
Mon, 13 Mar 2023 22:34:13 +0000 (15:34 -0700)
committerGitHub <noreply@github.com>
Mon, 13 Mar 2023 22:34:13 +0000 (15:34 -0700)
Part 2 of #3528

Adds hash salt that helps to avoid regressions where consecutive compressions use the same tag space with similar data (running zstd -b5e7 enwik8 -B128K reproduces this regression).

lib/common/bits.h
lib/compress/zstd_compress.c
lib/compress/zstd_compress_internal.h
lib/compress/zstd_lazy.c

index 4a9bbf58c5d89da16ed88577cfe27b0918d769f5..def56c474c380d9a734fbc249c1029e2a9c5ec54 100644 (file)
@@ -172,4 +172,29 @@ MEM_STATIC unsigned ZSTD_highbit32(U32 val)   /* compress, dictBuilder, decodeCo
     return 31 - ZSTD_countLeadingZeros32(val);
 }
 
+/* ZSTD_rotateRight_*():
+ * Rotates a bitfield to the right by "count" bits.
+ * https://en.wikipedia.org/w/index.php?title=Circular_shift&oldid=991635599#Implementing_circular_shifts
+ */
+MEM_STATIC
+U64 ZSTD_rotateRight_U64(U64 const value, U32 count) {
+    assert(count < 64);
+    count &= 0x3F; /* for fickle pattern recognition */
+    return (value >> count) | (U64)(value << ((0U - count) & 0x3F));
+}
+
+MEM_STATIC
+U32 ZSTD_rotateRight_U32(U32 const value, U32 count) {
+    assert(count < 32);
+    count &= 0x1F; /* for fickle pattern recognition */
+    return (value >> count) | (U32)(value << ((0U - count) & 0x1F));
+}
+
+MEM_STATIC
+U16 ZSTD_rotateRight_U16(U16 const value, U32 count) {
+    assert(count < 16);
+    count &= 0x0F; /* for fickle pattern recognition */
+    return (value >> count) | (U16)(value << ((0U - count) & 0x0F));
+}
+
 #endif /* ZSTD_BITS_H */
index 49aa385466d3eb404dde1f61420e6c076fafac60..77cb82e2178bfb405d5209493f54825c2c719cb8 100644 (file)
@@ -27,7 +27,7 @@
 #include "zstd_opt.h"
 #include "zstd_ldm.h"
 #include "zstd_compress_superblock.h"
-#include  "../common/bits.h"      /* ZSTD_highbit32 */
+#include  "../common/bits.h"      /* ZSTD_highbit32, ZSTD_rotateRight_U64 */
 
 /* ***************************************************************
 *  Tuning parameters
@@ -1907,6 +1907,19 @@ typedef enum {
     ZSTD_resetTarget_CCtx
 } ZSTD_resetTarget_e;
 
+/* Mixes bits in a 64 bits in a value, based on XXH3_rrmxmx */
+static U64 ZSTD_bitmix(U64 val, U64 len) {
+    val ^= ZSTD_rotateRight_U64(val, 49) ^ ZSTD_rotateRight_U64(val, 24);
+    val *= 0x9FB21C651E98DF25ULL;
+    val ^= (val >> 35) + len ;
+    val *= 0x9FB21C651E98DF25ULL;
+    return val ^ (val >> 28);
+}
+
+/* Mixes in the hashSalt and hashSaltEntropy to create a new hashSalt */
+static void ZSTD_advanceHashSalt(ZSTD_matchState_t* ms) {
+    ms->hashSalt = ZSTD_bitmix(ms->hashSalt, 8) ^ ZSTD_bitmix((U64) ms->hashSaltEntropy, 4);
+}
 
 static size_t
 ZSTD_reset_matchState(ZSTD_matchState_t* ms,
@@ -1958,7 +1971,17 @@ ZSTD_reset_matchState(ZSTD_matchState_t* ms,
     if (ZSTD_rowMatchFinderUsed(cParams->strategy, useRowMatchFinder)) {
         /* Row match finder needs an additional table of hashes ("tags") */
         size_t const tagTableSize = hSize;
-        ms->tagTable = (BYTE*) ZSTD_cwksp_reserve_aligned_init_once(ws, tagTableSize);
+        /* We want to generate a new salt in case we reset a Cctx, but we always want to use
+         * 0 when we reset a Cdict */
+        if(forWho == ZSTD_resetTarget_CCtx) {
+            ms->tagTable = (BYTE*) ZSTD_cwksp_reserve_aligned_init_once(ws, tagTableSize);
+            ZSTD_advanceHashSalt(ms);
+        } else {
+            /* When we are not salting we want to always memset the memory */
+            ms->tagTable = (BYTE*) ZSTD_cwksp_reserve_aligned(ws, tagTableSize);
+            ZSTD_memset(ms->tagTable, 0, tagTableSize);
+            ms->hashSalt = 0;
+        }
         {   /* Switch to 32-entry rows if searchLog is 5 (or more) */
             U32 const rowLog = BOUNDED(4, cParams->searchLog, 6);
             assert(cParams->hashLog >= rowLog);
@@ -2364,8 +2387,9 @@ static size_t ZSTD_resetCCtx_byCopyingCDict(ZSTD_CCtx* cctx,
         if (ZSTD_rowMatchFinderUsed(cdict_cParams->strategy, cdict->useRowMatchFinder)) {
             size_t const tagTableSize = hSize;
             ZSTD_memcpy(cctx->blockState.matchState.tagTable,
-                cdict->matchState.tagTable,
-                tagTableSize);
+                        cdict->matchState.tagTable,
+                        tagTableSize);
+            cctx->blockState.matchState.hashSalt = cdict->matchState.hashSalt;
         }
     }
 
index 789820ad0c9f1794116c6cccc067217c1df50c51..0435e4c37231fbcaf9bf14775d3836e1d489d901 100644 (file)
@@ -228,6 +228,8 @@ struct ZSTD_matchState_t {
     U32 rowHashLog;                          /* For row-based matchfinder: Hashlog based on nb of rows in the hashTable.*/
     BYTE* tagTable;                          /* For row-based matchFinder: A row-based table containing the hashes and head index. */
     U32 hashCache[ZSTD_ROW_HASH_CACHE_SIZE]; /* For row-based matchFinder: a cache of hashes to improve speed */
+    U64 hashSalt;                            /* For row-based matchFinder: salts the hash for re-use of tag table */
+    U32 hashSaltEntropy;                     /* For row-based matchFinder: collects entropy for salt generation */
 
     U32* hashTable;
     U32* hashTable3;
@@ -787,28 +789,35 @@ ZSTD_count_2segments(const BYTE* ip, const BYTE* match,
  *  Hashes
  ***************************************/
 static const U32 prime3bytes = 506832829U;
-static U32    ZSTD_hash3(U32 u, U32 h) { assert(h <= 32); return ((u << (32-24)) * prime3bytes)  >> (32-h) ; }
-MEM_STATIC size_t ZSTD_hash3Ptr(const void* ptr, U32 h) { return ZSTD_hash3(MEM_readLE32(ptr), h); } /* only in zstd_opt.h */
+static U32    ZSTD_hash3(U32 u, U32 h, U32 s) { assert(h <= 32); return (((u << (32-24)) * prime3bytes) ^ s)  >> (32-h) ; }
+MEM_STATIC size_t ZSTD_hash3Ptr(const void* ptr, U32 h) { return ZSTD_hash3(MEM_readLE32(ptr), h, 0); } /* only in zstd_opt.h */
+MEM_STATIC size_t ZSTD_hash3PtrS(const void* ptr, U32 h, U32 s) { return ZSTD_hash3(MEM_readLE32(ptr), h, s); }
 
 static const U32 prime4bytes = 2654435761U;
-static U32    ZSTD_hash4(U32 u, U32 h) { assert(h <= 32); return (u * prime4bytes) >> (32-h) ; }
-static size_t ZSTD_hash4Ptr(const void* ptr, U32 h) { return ZSTD_hash4(MEM_readLE32(ptr), h); }
+static U32    ZSTD_hash4(U32 u, U32 h, U32 s) { assert(h <= 32); return ((u * prime4bytes) ^ s) >> (32-h) ; }
+static size_t ZSTD_hash4Ptr(const void* ptr, U32 h) { return ZSTD_hash4(MEM_readLE32(ptr), h, 0); }
+static size_t ZSTD_hash4PtrS(const void* ptr, U32 h, U32 s) { return ZSTD_hash4(MEM_readLE32(ptr), h, s); }
 
 static const U64 prime5bytes = 889523592379ULL;
-static size_t ZSTD_hash5(U64 u, U32 h) { assert(h <= 64); return (size_t)(((u  << (64-40)) * prime5bytes) >> (64-h)) ; }
-static size_t ZSTD_hash5Ptr(const void* p, U32 h) { return ZSTD_hash5(MEM_readLE64(p), h); }
+static size_t ZSTD_hash5(U64 u, U32 h, U64 s) { assert(h <= 64); return (size_t)((((u  << (64-40)) * prime5bytes) ^ s) >> (64-h)) ; }
+static size_t ZSTD_hash5Ptr(const void* p, U32 h) { return ZSTD_hash5(MEM_readLE64(p), h, 0); }
+static size_t ZSTD_hash5PtrS(const void* p, U32 h, U64 s) { return ZSTD_hash5(MEM_readLE64(p), h, s); }
 
 static const U64 prime6bytes = 227718039650203ULL;
-static size_t ZSTD_hash6(U64 u, U32 h) { assert(h <= 64); return (size_t)(((u  << (64-48)) * prime6bytes) >> (64-h)) ; }
-static size_t ZSTD_hash6Ptr(const void* p, U32 h) { return ZSTD_hash6(MEM_readLE64(p), h); }
+static size_t ZSTD_hash6(U64 u, U32 h, U64 s) { assert(h <= 64); return (size_t)((((u  << (64-48)) * prime6bytes) ^ s) >> (64-h)) ; }
+static size_t ZSTD_hash6Ptr(const void* p, U32 h) { return ZSTD_hash6(MEM_readLE64(p), h, 0); }
+static size_t ZSTD_hash6PtrS(const void* p, U32 h, U64 s) { return ZSTD_hash6(MEM_readLE64(p), h, s); }
 
 static const U64 prime7bytes = 58295818150454627ULL;
-static size_t ZSTD_hash7(U64 u, U32 h) { assert(h <= 64); return (size_t)(((u  << (64-56)) * prime7bytes) >> (64-h)) ; }
-static size_t ZSTD_hash7Ptr(const void* p, U32 h) { return ZSTD_hash7(MEM_readLE64(p), h); }
+static size_t ZSTD_hash7(U64 u, U32 h, U64 s) { assert(h <= 64); return (size_t)((((u  << (64-56)) * prime7bytes) ^ s) >> (64-h)) ; }
+static size_t ZSTD_hash7Ptr(const void* p, U32 h) { return ZSTD_hash7(MEM_readLE64(p), h, 0); }
+static size_t ZSTD_hash7PtrS(const void* p, U32 h, U64 s) { return ZSTD_hash7(MEM_readLE64(p), h, s); }
 
 static const U64 prime8bytes = 0xCF1BBCDCB7A56463ULL;
-static size_t ZSTD_hash8(U64 u, U32 h) { assert(h <= 64); return (size_t)(((u) * prime8bytes) >> (64-h)) ; }
-static size_t ZSTD_hash8Ptr(const void* p, U32 h) { return ZSTD_hash8(MEM_readLE64(p), h); }
+static size_t ZSTD_hash8(U64 u, U32 h, U64 s) { assert(h <= 64); return (size_t)((((u) * prime8bytes)  ^ s) >> (64-h)) ; }
+static size_t ZSTD_hash8Ptr(const void* p, U32 h) { return ZSTD_hash8(MEM_readLE64(p), h, 0); }
+static size_t ZSTD_hash8PtrS(const void* p, U32 h, U64 s) { return ZSTD_hash8(MEM_readLE64(p), h, s); }
+
 
 MEM_STATIC FORCE_INLINE_ATTR
 size_t ZSTD_hashPtr(const void* p, U32 hBits, U32 mls)
@@ -828,6 +837,24 @@ size_t ZSTD_hashPtr(const void* p, U32 hBits, U32 mls)
     }
 }
 
+MEM_STATIC FORCE_INLINE_ATTR
+size_t ZSTD_hashPtrSalted(const void* p, U32 hBits, U32 mls, const U64 hashSalt) {
+    /* Although some of these hashes do support hBits up to 64, some do not.
+     * To be on the safe side, always avoid hBits > 32. */
+    assert(hBits <= 32);
+
+    switch(mls)
+    {
+        default:
+        case 4: return ZSTD_hash4PtrS(p, hBits, (U32)hashSalt);
+        case 5: return ZSTD_hash5PtrS(p, hBits, hashSalt);
+        case 6: return ZSTD_hash6PtrS(p, hBits, hashSalt);
+        case 7: return ZSTD_hash7PtrS(p, hBits, hashSalt);
+        case 8: return ZSTD_hash8PtrS(p, hBits, hashSalt);
+    }
+}
+
+
 /** ZSTD_ipow() :
  * Return base^exponent.
  */
index d41478e99fe7b0ca1d669316f3c799e61d09bb57..b3e7bf75752c2323a88ceffdf0e8d227b1fa462e 100644 (file)
@@ -773,31 +773,6 @@ MEM_STATIC U32 ZSTD_VecMask_next(ZSTD_VecMask val) {
     return ZSTD_countTrailingZeros64(val);
 }
 
-/* ZSTD_rotateRight_*():
- * Rotates a bitfield to the right by "count" bits.
- * https://en.wikipedia.org/w/index.php?title=Circular_shift&oldid=991635599#Implementing_circular_shifts
- */
-FORCE_INLINE_TEMPLATE
-U64 ZSTD_rotateRight_U64(U64 const value, U32 count) {
-    assert(count < 64);
-    count &= 0x3F; /* for fickle pattern recognition */
-    return (value >> count) | (U64)(value << ((0U - count) & 0x3F));
-}
-
-FORCE_INLINE_TEMPLATE
-U32 ZSTD_rotateRight_U32(U32 const value, U32 count) {
-    assert(count < 32);
-    count &= 0x1F; /* for fickle pattern recognition */
-    return (value >> count) | (U32)(value << ((0U - count) & 0x1F));
-}
-
-FORCE_INLINE_TEMPLATE
-U16 ZSTD_rotateRight_U16(U16 const value, U32 count) {
-    assert(count < 16);
-    count &= 0x0F; /* for fickle pattern recognition */
-    return (value >> count) | (U16)(value << ((0U - count) & 0x0F));
-}
-
 /* ZSTD_row_nextIndex():
  * Returns the next index to insert at within a tagTable row, and updates the "head"
  * value to reflect the update. Essentially cycles backwards from [1, {entries per row})
@@ -850,7 +825,7 @@ FORCE_INLINE_TEMPLATE void ZSTD_row_fillHashCache(ZSTD_matchState_t* ms, const B
     U32 const lim = idx + MIN(ZSTD_ROW_HASH_CACHE_SIZE, maxElemsToPrefetch);
 
     for (; idx < lim; ++idx) {
-        U32 const hash = (U32)ZSTD_hashPtr(base + idx, hashLog + ZSTD_ROW_HASH_TAG_BITS, mls);
+        U32 const hash = (U32)ZSTD_hashPtrSalted(base + idx, hashLog + ZSTD_ROW_HASH_TAG_BITS, mls, ms->hashSalt);
         U32 const row = (hash >> ZSTD_ROW_HASH_TAG_BITS) << rowLog;
         ZSTD_row_prefetch(hashTable, tagTable, row, rowLog);
         ms->hashCache[idx & ZSTD_ROW_HASH_CACHE_MASK] = hash;
@@ -868,9 +843,10 @@ FORCE_INLINE_TEMPLATE void ZSTD_row_fillHashCache(ZSTD_matchState_t* ms, const B
 FORCE_INLINE_TEMPLATE U32 ZSTD_row_nextCachedHash(U32* cache, U32 const* hashTable,
                                                   BYTE const* tagTable, BYTE const* base,
                                                   U32 idx, U32 const hashLog,
-                                                  U32 const rowLog, U32 const mls)
+                                                  U32 const rowLog, U32 const mls,
+                                                  U64 const hashSalt)
 {
-    U32 const newHash = (U32)ZSTD_hashPtr(base+idx+ZSTD_ROW_HASH_CACHE_SIZE, hashLog + ZSTD_ROW_HASH_TAG_BITS, mls);
+    U32 const newHash = (U32)ZSTD_hashPtrSalted(base+idx+ZSTD_ROW_HASH_CACHE_SIZE, hashLog + ZSTD_ROW_HASH_TAG_BITS, mls, hashSalt);
     U32 const row = (newHash >> ZSTD_ROW_HASH_TAG_BITS) << rowLog;
     ZSTD_row_prefetch(hashTable, tagTable, row, rowLog);
     {   U32 const hash = cache[idx & ZSTD_ROW_HASH_CACHE_MASK];
@@ -890,21 +866,24 @@ FORCE_INLINE_TEMPLATE void ZSTD_row_update_internalImpl(ZSTD_matchState_t* ms,
     U32* const hashTable = ms->hashTable;
     BYTE* const tagTable = ms->tagTable;
     U32 const hashLog = ms->rowHashLog;
+    U32 hashSaltEntropyCollected = 0;
     const BYTE* const base = ms->window.base;
 
     DEBUGLOG(6, "ZSTD_row_update_internalImpl(): updateStartIdx=%u, updateEndIdx=%u", updateStartIdx, updateEndIdx);
     for (; updateStartIdx < updateEndIdx; ++updateStartIdx) {
-        U32 const hash = useCache ? ZSTD_row_nextCachedHash(ms->hashCache, hashTable, tagTable, base, updateStartIdx, hashLog, rowLog, mls)
-                                  : (U32)ZSTD_hashPtr(base + updateStartIdx, hashLog + ZSTD_ROW_HASH_TAG_BITS, mls);
+        U32 const hash = useCache ? ZSTD_row_nextCachedHash(ms->hashCache, hashTable, tagTable, base, updateStartIdx, hashLog, rowLog, mls, ms->hashSalt)
+                                  : (U32)ZSTD_hashPtrSalted(base + updateStartIdx, hashLog + ZSTD_ROW_HASH_TAG_BITS, mls, ms->hashSalt);
         U32 const relRow = (hash >> ZSTD_ROW_HASH_TAG_BITS) << rowLog;
         U32* const row = hashTable + relRow;
         BYTE* tagRow = tagTable + relRow;
         U32 const pos = ZSTD_row_nextIndex(tagRow, rowMask);
 
-        assert(hash == ZSTD_hashPtr(base + updateStartIdx, hashLog + ZSTD_ROW_HASH_TAG_BITS, mls));
+        assert(hash == ZSTD_hashPtrSalted(base + updateStartIdx, hashLog + ZSTD_ROW_HASH_TAG_BITS, mls, ms->hashSalt));
         tagRow[pos] = hash & ZSTD_ROW_HASH_TAG_MASK;
         row[pos] = updateStartIdx;
+        hashSaltEntropyCollected = hash;
     }
+    ms->hashSaltEntropy += hashSaltEntropyCollected; /* collect salt entropy */
 }
 
 /* ZSTD_row_update_internal():
@@ -1162,6 +1141,7 @@ size_t ZSTD_RowFindBestMatch(
     const U32 rowMask = rowEntries - 1;
     const U32 cappedSearchLog = MIN(cParams->searchLog, rowLog); /* nb of searches is capped at nb entries per row */
     const U32 groupWidth = ZSTD_row_matchMaskGroupWidth(rowEntries);
+    const U64 hashSalt = ms->hashSalt;
     U32 nbAttempts = 1U << cappedSearchLog;
     size_t ml=4-1;
 
@@ -1199,7 +1179,7 @@ size_t ZSTD_RowFindBestMatch(
     /* Update the hashTable and tagTable up to (but not including) ip */
     ZSTD_row_update_internal(ms, ip, mls, rowLog, rowMask, 1 /* useCache */);
     {   /* Get the hash for ip, compute the appropriate row */
-        U32 const hash = ZSTD_row_nextCachedHash(hashCache, hashTable, tagTable, base, curr, hashLog, rowLog, mls);
+        U32 const hash = ZSTD_row_nextCachedHash(hashCache, hashTable, tagTable, base, curr, hashLog, rowLog, mls, hashSalt);
         U32 const relRow = (hash >> ZSTD_ROW_HASH_TAG_BITS) << rowLog;
         U32 const tag = hash & ZSTD_ROW_HASH_TAG_MASK;
         U32* const row = hashTable + relRow;