]> git.ipfire.org Git - thirdparty/zstd.git/commitdiff
Optimize compression by avoiding unpredictable branches 4144/head
authorIlya Tokar <tokarip@google.com>
Wed, 18 Sep 2024 21:36:37 +0000 (17:36 -0400)
committerIlya Tokar <tokarip@google.com>
Fri, 20 Sep 2024 20:07:01 +0000 (16:07 -0400)
Avoid unpredictable branch. Use conditional move to generate the address
that is guaranteed to be safe and compare unconditionally.
Instead of

if (idx < limit && x[idx] == val ) // mispredicted idx < limit branch

Do

addr = cmov(safe,x+idx)
if (*addr == val && idx < limit) // almost always false so well predicted

Using microbenchmarks from https://github.com/google/fleetbench,
I get about ~10% speed-up:

name                                                                                          old cpu/op   new cpu/op    delta
BM_ZSTD_COMPRESS_Fleet/compression_level:-7/window_log:15                                     1.46ns ± 3%   1.31ns ± 7%   -9.88%  (p=0.000 n=35+38)
BM_ZSTD_COMPRESS_Fleet/compression_level:-7/window_log:16                                     1.41ns ± 3%   1.28ns ± 3%   -9.56%  (p=0.000 n=36+39)
BM_ZSTD_COMPRESS_Fleet/compression_level:-5/window_log:15                                     1.61ns ± 1%   1.43ns ± 3%  -10.70%  (p=0.000 n=30+39)
BM_ZSTD_COMPRESS_Fleet/compression_level:-5/window_log:16                                     1.54ns ± 2%   1.39ns ± 3%   -9.21%  (p=0.000 n=37+39)
BM_ZSTD_COMPRESS_Fleet/compression_level:-3/window_log:15                                     1.82ns ± 2%   1.61ns ± 3%  -11.31%  (p=0.000 n=37+40)
BM_ZSTD_COMPRESS_Fleet/compression_level:-3/window_log:16                                     1.73ns ± 3%   1.56ns ± 3%   -9.50%  (p=0.000 n=38+39)
BM_ZSTD_COMPRESS_Fleet/compression_level:-1/window_log:15                                     2.12ns ± 2%   1.79ns ± 3%  -15.55%  (p=0.000 n=34+39)
BM_ZSTD_COMPRESS_Fleet/compression_level:-1/window_log:16                                     1.99ns ± 3%   1.72ns ± 3%  -13.70%  (p=0.000 n=38+38)
BM_ZSTD_COMPRESS_Fleet/compression_level:0/window_log:15                                      3.22ns ± 3%   2.94ns ± 3%   -8.67%  (p=0.000 n=38+40)
BM_ZSTD_COMPRESS_Fleet/compression_level:0/window_log:16                                      3.19ns ± 4%   2.86ns ± 4%  -10.55%  (p=0.000 n=40+38)
BM_ZSTD_COMPRESS_Fleet/compression_level:1/window_log:15                                      2.60ns ± 3%   2.22ns ± 3%  -14.53%  (p=0.000 n=40+39)
BM_ZSTD_COMPRESS_Fleet/compression_level:1/window_log:16                                      2.46ns ± 3%   2.13ns ± 2%  -13.67%  (p=0.000 n=39+36)
BM_ZSTD_COMPRESS_Fleet/compression_level:2/window_log:15                                      2.69ns ± 3%   2.46ns ± 3%   -8.63%  (p=0.000 n=37+39)
BM_ZSTD_COMPRESS_Fleet/compression_level:2/window_log:16                                      2.63ns ± 3%   2.36ns ± 3%  -10.47%  (p=0.000 n=40+40)
BM_ZSTD_COMPRESS_Fleet/compression_level:3/window_log:15                                      3.20ns ± 2%   2.95ns ± 3%   -7.94%  (p=0.000 n=35+40)
BM_ZSTD_COMPRESS_Fleet/compression_level:3/window_log:16                                      3.20ns ± 4%   2.87ns ± 4%  -10.33%  (p=0.000 n=40+40)

I've also measured the impact on internal workloads and saw similar
~10% improvement in performance, measured by cpu usage/byte of data.

lib/compress/zstd_compress_internal.h
lib/compress/zstd_double_fast.c
lib/compress/zstd_fast.c

index ba1450852ed20042d17b41cd539c108eb7f1841c..cbcded28dbbc65b2bcc18d530e5a6f62fa4b49e0 100644 (file)
@@ -557,6 +557,23 @@ MEM_STATIC int ZSTD_cParam_withinBounds(ZSTD_cParameter cParam, int value)
     return 1;
 }
 
+/* ZSTD_selectAddr:
+ * @return a >= b ? trueAddr : falseAddr,
+ * tries to force branchless codegen. */
+MEM_STATIC const BYTE* ZSTD_selectAddr(U32 a, U32 b, const BYTE* trueAddr, const BYTE* falseAddr) {
+#if defined(__GNUC__) && defined(__x86_64__)
+    __asm__ (
+        "cmp %1, %2\n"
+        "cmova %3, %0\n"
+        : "+r"(trueAddr)
+        : "r"(a), "r"(b), "r"(falseAddr)
+        );
+    return trueAddr;
+#else
+    return a >= b ? trueAddr : falseAddr;
+#endif
+}
+
 /* ZSTD_noCompressBlock() :
  * Writes uncompressed block to dst buffer from given src.
  * Returns the size of the block */
index 1b5f64f20bb35f87f8fec94d72bb2910384298a2..819658caf77bf9c40fbfe1a0648f6ff94aadcef3 100644 (file)
@@ -140,11 +140,17 @@ size_t ZSTD_compressBlock_doubleFast_noDict_generic(
     U32 idxl1; /* the long match index for ip1 */
 
     const BYTE* matchl0; /* the long match for ip */
+    const BYTE* matchl0_safe; /* matchl0 or safe address */
     const BYTE* matchs0; /* the short match for ip */
     const BYTE* matchl1; /* the long match for ip1 */
+    const BYTE* matchs0_safe; /* matchs0 or safe address */
 
     const BYTE* ip = istart; /* the current position */
     const BYTE* ip1; /* the next position */
+    /* Array of ~random data, should have low probability of matching data
+     * we load from here instead of from tables, if matchl0/matchl1 are
+     * invalid indices. Used to avoid unpredictable branches. */
+    const BYTE dummy[] = {0x12,0x34,0x56,0x78,0x9a,0xbc,0xde,0xf0,0xe2,0xb4};
 
     DEBUGLOG(5, "ZSTD_compressBlock_doubleFast_noDict_generic");
 
@@ -191,24 +197,29 @@ size_t ZSTD_compressBlock_doubleFast_noDict_generic(
 
             hl1 = ZSTD_hashPtr(ip1, hBitsL, 8);
 
-            if (idxl0 > prefixLowestIndex) {
-                /* check prefix long match */
-                if (MEM_read64(matchl0) == MEM_read64(ip)) {
-                    mLength = ZSTD_count(ip+8, matchl0+8, iend) + 8;
-                    offset = (U32)(ip-matchl0);
-                    while (((ip>anchor) & (matchl0>prefixLowest)) && (ip[-1] == matchl0[-1])) { ip--; matchl0--; mLength++; } /* catch up */
-                    goto _match_found;
-                }
+            /* idxl0 > prefixLowestIndex is a (somewhat) unpredictable branch.
+             * However expression below complies into conditional move. Since
+             * match is unlikely and we only *branch* on idxl0 > prefixLowestIndex
+             * if there is a match, all branches become predictable. */
+            matchl0_safe = ZSTD_selectAddr(prefixLowestIndex, idxl0, &dummy[0], matchl0);
+
+            /* check prefix long match */
+            if (MEM_read64(matchl0_safe) == MEM_read64(ip) && matchl0_safe == matchl0) {
+                mLength = ZSTD_count(ip+8, matchl0+8, iend) + 8;
+                offset = (U32)(ip-matchl0);
+                while (((ip>anchor) & (matchl0>prefixLowest)) && (ip[-1] == matchl0[-1])) { ip--; matchl0--; mLength++; } /* catch up */
+                goto _match_found;
             }
 
             idxl1 = hashLong[hl1];
             matchl1 = base + idxl1;
 
-            if (idxs0 > prefixLowestIndex) {
-                /* check prefix short match */
-                if (MEM_read32(matchs0) == MEM_read32(ip)) {
-                    goto _search_next_long;
-                }
+            /* Same optimization as matchl0 above */
+            matchs0_safe = ZSTD_selectAddr(prefixLowestIndex, idxs0, &dummy[0], matchs0);
+
+            /* check prefix short match */
+            if(MEM_read32(matchs0_safe) == MEM_read32(ip) && matchs0_safe == matchs0) {
+                  goto _search_next_long;
             }
 
             if (ip1 >= nextStep) {
index c6baa49d4dd0a576c943be0463b6a7a2c2aff1ea..838a18ee5cff90fc25e1017c53f443cbc1c8131b 100644 (file)
@@ -162,6 +162,11 @@ size_t ZSTD_compressBlock_fast_noDict_generic(
     const BYTE* const prefixStart = base + prefixStartIndex;
     const BYTE* const iend = istart + srcSize;
     const BYTE* const ilimit = iend - HASH_READ_SIZE;
+    /* Array of ~random data, should have low probability of matching data
+     * we load from here instead of from tables, if the index is invalid.
+     * Used to avoid unpredictable branches. */
+    const BYTE dummy[] = {0x12,0x34,0x56,0x78,0x9a,0xbc,0xde,0xf0,0xe2,0xb4};
+    const BYTE *mvalAddr;
 
     const BYTE* anchor = istart;
     const BYTE* ip0 = istart;
@@ -246,15 +251,18 @@ _start: /* Requires: ip0 */
             goto _match;
         }
 
+        /* idx >= prefixStartIndex is a (somewhat) unpredictable branch.
+         * However expression below complies into conditional move. Since
+         * match is unlikely and we only *branch* on idxl0 > prefixLowestIndex
+         * if there is a match, all branches become predictable. */
+        mvalAddr = base + idx;
+        mvalAddr = ZSTD_selectAddr(idx, prefixStartIndex, mvalAddr, &dummy[0]);
+
         /* load match for ip[0] */
-        if (idx >= prefixStartIndex) {
-            mval = MEM_read32(base + idx);
-        } else {
-            mval = MEM_read32(ip0) ^ 1; /* guaranteed to not match. */
-        }
+        mval = MEM_read32(mvalAddr);
 
         /* check match at ip[0] */
-        if (MEM_read32(ip0) == mval) {
+        if (MEM_read32(ip0) == mval && idx >= prefixStartIndex) {
             /* found a match! */
 
             /* First write next hash table entry; we've already calculated it.
@@ -281,15 +289,15 @@ _start: /* Requires: ip0 */
         current0 = (U32)(ip0 - base);
         hashTable[hash0] = current0;
 
+        mvalAddr = base + idx;
+        mvalAddr = ZSTD_selectAddr(idx, prefixStartIndex, mvalAddr, &dummy[0]);
+
         /* load match for ip[0] */
-        if (idx >= prefixStartIndex) {
-            mval = MEM_read32(base + idx);
-        } else {
-            mval = MEM_read32(ip0) ^ 1; /* guaranteed to not match. */
-        }
+        mval = MEM_read32(mvalAddr);
+
 
         /* check match at ip[0] */
-        if (MEM_read32(ip0) == mval) {
+        if (MEM_read32(ip0) == mval && idx >= prefixStartIndex) {
             /* found a match! */
 
             /* first write next hash table entry; we've already calculated it */