]> git.ipfire.org Git - thirdparty/zstd.git/commitdiff
AArch64: Add SVE2 implementation of histogram computation
authorArpad Panyik <Arpad.Panyik@arm.com>
Wed, 11 Jun 2025 12:14:22 +0000 (12:14 +0000)
committerArpad Panyik <Arpad.Panyik@arm.com>
Wed, 11 Jun 2025 12:14:22 +0000 (12:14 +0000)
The existing scalar implementation uses a 4-way pipelined histogram
calculation which is very efficient on out-of-order CPUs. However,
this can be further accelerated using the SVE2 HISTSEG instructions -
which compute a histogram for 16 byte chunks in a vector register.

On a system with 128-bit vectors (VL128) we need 16 HISTSEG executions
to compute the histogram for the whole symbol space (0..255) of 16
bytes input. However we can only accumulate 15 of such 16 byte strips
before possible overflow. So we need to extend and save the 8-bit
histogram accumulators to 16-bit after every 240 byte chunks of input.
To store all in registers we would need 32 128-bit registers. Longer
SVE2 vectors could help here, if such machines become available.

The maximum input block size in Zstd is 128 KiB, so 16-bit accumulators
would not be enough. However an LZ pass will prepend the histogram
calculation, so it is impossible (my assumption) to overflow the 16-bit
accumulators.

The symbol distribution is also not uniform, the lower values are more
common, so we used a 3 pass algorithm to prevent stack spilling. In the
first pass we only compute histograms for 64 symbols (4-way SIMD) while
also computing the maximum symbol value. If we have symbol values
larger than 64 we start the second pass to compute the next 96 elements
of the histogram. The final pass calculates the remaining part of the
histogram (256 symbols in total) if needed. This split of histogram
generation gave the best overall results for performance.

This implementation is the best performing of a number of different
cache blocking schemes tested.

Compression uplifts on a Neoverse V2 system, using Zstd-1.5.8
(e26dde3d) as a baseline, compiled with "-O3 -march=armv8.2-a+sve2":

                 Clang-20    GCC-14
 1#silesia.tar:   +6.173%   +5.987%
 2#silesia.tar:   +5.200%   +5.011%
 3#silesia.tar:   +4.332%   +5.031%
 4#silesia.tar:   +2.789%   +3.064%
 5#silesia.tar:   +2.028%   +1.838%
 6#silesia.tar:   +1.562%   +1.340%
 7#silesia.tar:   +1.160%   +0.959%

lib/common/compiler.h
lib/compress/hist.c
lib/compress/hist.h

index cafb35b71e40c52372ac2d0f371f7974d809833a..6131ad0688f158a27b0bb4548d78187f49b39168 100644 (file)
 #  if defined(__ARM_NEON) || defined(_M_ARM64)
 #    define ZSTD_ARCH_ARM_NEON
 #  endif
+#  if defined(__ARM_FEATURE_SVE)
+#    define ZSTD_ARCH_ARM_SVE
+#  endif
+#  if defined(__ARM_FEATURE_SVE2)
+#    define ZSTD_ARCH_ARM_SVE2
+#  endif
 # if defined(__riscv) && defined(__riscv_vector)
 #   define ZSTD_ARCH_RISCV_RVV
 # endif
 #  elif defined(ZSTD_ARCH_ARM_NEON)
 #    include <arm_neon.h>
 #  endif
+#  if defined(ZSTD_ARCH_ARM_SVE) || defined(ZSTD_ARCH_ARM_SVE2)
+#    include <arm_sve.h>
+#  endif
 #  if defined(ZSTD_ARCH_RISCV_RVV)
 #    include <riscv_vector.h>
 #  endif
index 4ccf9a90a9ead81993135c13ee23a0a2d7cb9bf5..3692bc250cadbb831133b432094a7115efed848e 100644 (file)
 #include "../common/error_private.h"   /* ERROR */
 #include "hist.h"
 
+#if defined(ZSTD_ARCH_ARM_SVE2)
+#define HIST_FAST_THRESHOLD 500
+#else
+#define HIST_FAST_THRESHOLD 1500
+#endif
+
 
 /* --- Error management --- */
 unsigned HIST_isError(size_t code) { return ERR_isError(code); }
@@ -65,6 +71,244 @@ unsigned HIST_count_simple(unsigned* count, unsigned* maxSymbolValuePtr,
 
 typedef enum { trustInput, checkMaxSymbolValue } HIST_checkInput_e;
 
+#if defined(ZSTD_ARCH_ARM_SVE2)
+FORCE_INLINE_TEMPLATE size_t min_size(size_t a, size_t b) { return a < b ? a : b; }
+
+static
+svuint16_t HIST_count_6_sve2(const BYTE* const src, size_t size, U32* const dst,
+                             const svuint8_t c0, const svuint8_t c1,
+                             const svuint8_t c2, const svuint8_t c3,
+                             const svuint8_t c4, const svuint8_t c5,
+                             const svuint16_t histmax, size_t maxCount)
+{
+    const svbool_t vl128 = svptrue_pat_b8(SV_VL16);
+    svuint16_t hh0 = svdup_n_u16(0);
+    svuint16_t hh1 = svdup_n_u16(0);
+    svuint16_t hh2 = svdup_n_u16(0);
+    svuint16_t hh3 = svdup_n_u16(0);
+    svuint16_t hh4 = svdup_n_u16(0);
+    svuint16_t hh5 = svdup_n_u16(0);
+    svuint16_t hh6 = svdup_n_u16(0);
+    svuint16_t hh7 = svdup_n_u16(0);
+    svuint16_t hh8 = svdup_n_u16(0);
+    svuint16_t hh9 = svdup_n_u16(0);
+    svuint16_t hha = svdup_n_u16(0);
+    svuint16_t hhb = svdup_n_u16(0);
+
+    size_t i = 0;
+    while (i < size) {
+        /* We can only accumulate 15 (15 * 16 <= 255) iterations of histogram
+         * in 8-bit accumulators! */
+        const size_t size240 = min_size(i + 240, size);
+
+        svbool_t pred = svwhilelt_b8_u64(i, size);
+        svuint8_t c = svld1rq_u8(pred, src + i);
+        svuint8_t h0 = svhistseg_u8(c0, c);
+        svuint8_t h1 = svhistseg_u8(c1, c);
+        svuint8_t h2 = svhistseg_u8(c2, c);
+        svuint8_t h3 = svhistseg_u8(c3, c);
+        svuint8_t h4 = svhistseg_u8(c4, c);
+        svuint8_t h5 = svhistseg_u8(c5, c);
+
+        for (i += 16; i < size240; i += 16) {
+            pred = svwhilelt_b8_u64(i, size);
+            c = svld1rq_u8(pred, src + i);
+            h0 = svadd_u8_x(vl128, h0, svhistseg_u8(c0, c));
+            h1 = svadd_u8_x(vl128, h1, svhistseg_u8(c1, c));
+            h2 = svadd_u8_x(vl128, h2, svhistseg_u8(c2, c));
+            h3 = svadd_u8_x(vl128, h3, svhistseg_u8(c3, c));
+            h4 = svadd_u8_x(vl128, h4, svhistseg_u8(c4, c));
+            h5 = svadd_u8_x(vl128, h5, svhistseg_u8(c5, c));
+        }
+
+        hh0 = svaddwb_u16(hh0, h0);
+        hh1 = svaddwt_u16(hh1, h0);
+        hh2 = svaddwb_u16(hh2, h1);
+        hh3 = svaddwt_u16(hh3, h1);
+        hh4 = svaddwb_u16(hh4, h2);
+        hh5 = svaddwt_u16(hh5, h2);
+        hh6 = svaddwb_u16(hh6, h3);
+        hh7 = svaddwt_u16(hh7, h3);
+        hh8 = svaddwb_u16(hh8, h4);
+        hh9 = svaddwt_u16(hh9, h4);
+        hha = svaddwb_u16(hha, h5);
+        hhb = svaddwt_u16(hhb, h5);
+    }
+
+    svst1_u32(svwhilelt_b32_u64( 0, maxCount), dst +  0, svshllb_n_u32(hh0, 0));
+    svst1_u32(svwhilelt_b32_u64( 4, maxCount), dst +  4, svshllt_n_u32(hh0, 0));
+    svst1_u32(svwhilelt_b32_u64( 8, maxCount), dst +  8, svshllb_n_u32(hh1, 0));
+    svst1_u32(svwhilelt_b32_u64(12, maxCount), dst + 12, svshllt_n_u32(hh1, 0));
+    svst1_u32(svwhilelt_b32_u64(16, maxCount), dst + 16, svshllb_n_u32(hh2, 0));
+    svst1_u32(svwhilelt_b32_u64(20, maxCount), dst + 20, svshllt_n_u32(hh2, 0));
+    svst1_u32(svwhilelt_b32_u64(24, maxCount), dst + 24, svshllb_n_u32(hh3, 0));
+    svst1_u32(svwhilelt_b32_u64(28, maxCount), dst + 28, svshllt_n_u32(hh3, 0));
+    svst1_u32(svwhilelt_b32_u64(32, maxCount), dst + 32, svshllb_n_u32(hh4, 0));
+    svst1_u32(svwhilelt_b32_u64(36, maxCount), dst + 36, svshllt_n_u32(hh4, 0));
+    svst1_u32(svwhilelt_b32_u64(40, maxCount), dst + 40, svshllb_n_u32(hh5, 0));
+    svst1_u32(svwhilelt_b32_u64(44, maxCount), dst + 44, svshllt_n_u32(hh5, 0));
+    svst1_u32(svwhilelt_b32_u64(48, maxCount), dst + 48, svshllb_n_u32(hh6, 0));
+    svst1_u32(svwhilelt_b32_u64(52, maxCount), dst + 52, svshllt_n_u32(hh6, 0));
+    svst1_u32(svwhilelt_b32_u64(56, maxCount), dst + 56, svshllb_n_u32(hh7, 0));
+    svst1_u32(svwhilelt_b32_u64(60, maxCount), dst + 60, svshllt_n_u32(hh7, 0));
+    svst1_u32(svwhilelt_b32_u64(64, maxCount), dst + 64, svshllb_n_u32(hh8, 0));
+    svst1_u32(svwhilelt_b32_u64(68, maxCount), dst + 68, svshllt_n_u32(hh8, 0));
+    svst1_u32(svwhilelt_b32_u64(72, maxCount), dst + 72, svshllb_n_u32(hh9, 0));
+    svst1_u32(svwhilelt_b32_u64(76, maxCount), dst + 76, svshllt_n_u32(hh9, 0));
+    svst1_u32(svwhilelt_b32_u64(80, maxCount), dst + 80, svshllb_n_u32(hha, 0));
+    svst1_u32(svwhilelt_b32_u64(84, maxCount), dst + 84, svshllt_n_u32(hha, 0));
+    svst1_u32(svwhilelt_b32_u64(88, maxCount), dst + 88, svshllb_n_u32(hhb, 0));
+    svst1_u32(svwhilelt_b32_u64(92, maxCount), dst + 92, svshllt_n_u32(hhb, 0));
+
+    hh0 = svmax_u16_x(vl128, hh0, hh1);
+    hh2 = svmax_u16_x(vl128, hh2, hh3);
+    hh4 = svmax_u16_x(vl128, hh4, hh5);
+    hh6 = svmax_u16_x(vl128, hh6, hh7);
+    hh8 = svmax_u16_x(vl128, hh8, hh9);
+    hha = svmax_u16_x(vl128, hha, hhb);
+    hh0 = svmax_u16_x(vl128, hh0, hh2);
+    hh4 = svmax_u16_x(vl128, hh4, hh6);
+    hh8 = svmax_u16_x(vl128, hh8, hha);
+    hh0 = svmax_u16_x(vl128, hh0, hh4);
+    hh8 = svmax_u16_x(vl128, hh8, histmax);
+    return svmax_u16_x(vl128, hh0, hh8);
+}
+
+static size_t HIST_count_sve2(unsigned* count, unsigned* maxSymbolValuePtr,
+                              const void* source, size_t sourceSize,
+                              HIST_checkInput_e check)
+{
+    const BYTE* ip = (const BYTE*)source;
+    const size_t maxCount = *maxSymbolValuePtr + 1;
+
+    assert(*maxSymbolValuePtr <= 255);
+    if (!sourceSize) {
+        ZSTD_memset(count, 0, maxCount * sizeof(*count));
+        *maxSymbolValuePtr = 0;
+        return 0;
+    }
+
+    {   const svbool_t vl128 = svptrue_pat_b8(SV_VL16);
+        const svuint8_t c0 = svreinterpret_u8(svindex_u32(0x0C040800, 0x01010101));
+        const svuint8_t c1 = svadd_n_u8_x(vl128, c0, 16);
+        const svuint8_t c2 = svadd_n_u8_x(vl128, c0, 32);
+        const svuint8_t c3 = svadd_n_u8_x(vl128, c1, 32);
+
+        svuint8_t symbolMax = svdup_n_u8(0);
+        svuint16_t hh0 = svdup_n_u16(0);
+        svuint16_t hh1 = svdup_n_u16(0);
+        svuint16_t hh2 = svdup_n_u16(0);
+        svuint16_t hh3 = svdup_n_u16(0);
+        svuint16_t hh4 = svdup_n_u16(0);
+        svuint16_t hh5 = svdup_n_u16(0);
+        svuint16_t hh6 = svdup_n_u16(0);
+        svuint16_t hh7 = svdup_n_u16(0);
+        svuint16_t max;
+        size_t maxSymbolValue;
+
+        size_t i = 0;
+        while (i < sourceSize) {
+            /* We can only accumulate 15 (15 * 16 <= 255) iterations of
+             * histogram in 8-bit accumulators! */
+            const size_t size240 = min_size(i + 240, sourceSize);
+
+            svbool_t pred = svwhilelt_b8_u64(i, sourceSize);
+            svuint8_t c = svld1rq_u8(pred, ip + i);
+            svuint8_t h0 = svhistseg_u8(c0, c);
+            svuint8_t h1 = svhistseg_u8(c1, c);
+            svuint8_t h2 = svhistseg_u8(c2, c);
+            svuint8_t h3 = svhistseg_u8(c3, c);
+            symbolMax = svmax_u8_x(vl128, symbolMax, c);
+
+            for (i += 16; i < size240; i += 16) {
+                pred = svwhilelt_b8_u64(i, sourceSize);
+                c = svld1rq_u8(pred, ip + i);
+                h0 = svadd_u8_x(vl128, h0, svhistseg_u8(c0, c));
+                h1 = svadd_u8_x(vl128, h1, svhistseg_u8(c1, c));
+                h2 = svadd_u8_x(vl128, h2, svhistseg_u8(c2, c));
+                h3 = svadd_u8_x(vl128, h3, svhistseg_u8(c3, c));
+                symbolMax = svmax_u8_x(vl128, symbolMax, c);
+            }
+
+            hh0 = svaddwb_u16(hh0, h0);
+            hh1 = svaddwt_u16(hh1, h0);
+            hh2 = svaddwb_u16(hh2, h1);
+            hh3 = svaddwt_u16(hh3, h1);
+            hh4 = svaddwb_u16(hh4, h2);
+            hh5 = svaddwt_u16(hh5, h2);
+            hh6 = svaddwb_u16(hh6, h3);
+            hh7 = svaddwt_u16(hh7, h3);
+        }
+        maxSymbolValue = svmaxv_u8(vl128, symbolMax);
+
+        if (check && maxSymbolValue > *maxSymbolValuePtr) return ERROR(maxSymbolValue_tooSmall);
+        *maxSymbolValuePtr = maxSymbolValue;
+
+        /* If the buffer size is not divisible by 16, the last elements of the final
+         * vector register read will be zeros, and these elements must be subtracted
+         * from the histogram.
+         */
+        hh0 = svsub_n_u16_m(svptrue_pat_b32(SV_VL1), hh0, -sourceSize & 15);
+
+        svst1_u32(svwhilelt_b32_u64( 0, maxCount), count +  0, svshllb_n_u32(hh0, 0));
+        svst1_u32(svwhilelt_b32_u64( 4, maxCount), count +  4, svshllt_n_u32(hh0, 0));
+        svst1_u32(svwhilelt_b32_u64( 8, maxCount), count +  8, svshllb_n_u32(hh1, 0));
+        svst1_u32(svwhilelt_b32_u64(12, maxCount), count + 12, svshllt_n_u32(hh1, 0));
+        svst1_u32(svwhilelt_b32_u64(16, maxCount), count + 16, svshllb_n_u32(hh2, 0));
+        svst1_u32(svwhilelt_b32_u64(20, maxCount), count + 20, svshllt_n_u32(hh2, 0));
+        svst1_u32(svwhilelt_b32_u64(24, maxCount), count + 24, svshllb_n_u32(hh3, 0));
+        svst1_u32(svwhilelt_b32_u64(28, maxCount), count + 28, svshllt_n_u32(hh3, 0));
+        svst1_u32(svwhilelt_b32_u64(32, maxCount), count + 32, svshllb_n_u32(hh4, 0));
+        svst1_u32(svwhilelt_b32_u64(36, maxCount), count + 36, svshllt_n_u32(hh4, 0));
+        svst1_u32(svwhilelt_b32_u64(40, maxCount), count + 40, svshllb_n_u32(hh5, 0));
+        svst1_u32(svwhilelt_b32_u64(44, maxCount), count + 44, svshllt_n_u32(hh5, 0));
+        svst1_u32(svwhilelt_b32_u64(48, maxCount), count + 48, svshllb_n_u32(hh6, 0));
+        svst1_u32(svwhilelt_b32_u64(52, maxCount), count + 52, svshllt_n_u32(hh6, 0));
+        svst1_u32(svwhilelt_b32_u64(56, maxCount), count + 56, svshllb_n_u32(hh7, 0));
+        svst1_u32(svwhilelt_b32_u64(60, maxCount), count + 60, svshllt_n_u32(hh7, 0));
+
+        hh0 = svmax_u16_x(vl128, hh0, hh1);
+        hh2 = svmax_u16_x(vl128, hh2, hh3);
+        hh4 = svmax_u16_x(vl128, hh4, hh5);
+        hh6 = svmax_u16_x(vl128, hh6, hh7);
+        hh0 = svmax_u16_x(vl128, hh0, hh2);
+        hh4 = svmax_u16_x(vl128, hh4, hh6);
+        max = svmax_u16_x(vl128, hh0, hh4);
+
+        maxSymbolValue = min_size(maxSymbolValue, maxCount);
+        if (maxSymbolValue >= 64) {
+            const svuint8_t c4 = svadd_n_u8_x(vl128, c0,  64);
+            const svuint8_t c5 = svadd_n_u8_x(vl128, c1,  64);
+            const svuint8_t c6 = svadd_n_u8_x(vl128, c2,  64);
+            const svuint8_t c7 = svadd_n_u8_x(vl128, c3,  64);
+            const svuint8_t c8 = svadd_n_u8_x(vl128, c0, 128);
+            const svuint8_t c9 = svadd_n_u8_x(vl128, c1, 128);
+
+            max = HIST_count_6_sve2(ip, sourceSize, count + 64, c4, c5, c6, c7,
+                                    c8, c9, max, maxCount - 64);
+
+            if (maxSymbolValue >= 160) {
+                const svuint8_t ca = svadd_n_u8_x(vl128, c2, 128);
+                const svuint8_t cb = svadd_n_u8_x(vl128, c3, 128);
+                const svuint8_t cc = svadd_n_u8_x(vl128, c4, 128);
+                const svuint8_t cd = svadd_n_u8_x(vl128, c5, 128);
+                const svuint8_t ce = svadd_n_u8_x(vl128, c6, 128);
+                const svuint8_t cf = svadd_n_u8_x(vl128, c7, 128);
+
+                max = HIST_count_6_sve2(ip, sourceSize, count + 160, ca, cb, cc,
+                                        cd, ce, cf, max, maxCount - 160);
+            } else if (maxCount > 160) {
+                ZSTD_memset(count + 160, 0, (maxCount - 160) * sizeof(*count));
+            }
+        } else if (maxCount > 64) {
+            ZSTD_memset(count + 64, 0, (maxCount - 64) * sizeof(*count));
+        }
+
+        return svmaxv_u16(vl128, max);
+    }
+}
+#endif
+
 /* HIST_count_parallel_wksp() :
  * store histogram into 4 intermediate tables, recombined at the end.
  * this design makes better use of OoO cpus,
@@ -73,8 +317,8 @@ typedef enum { trustInput, checkMaxSymbolValue } HIST_checkInput_e;
  * `workSpace` must be a U32 table of size >= HIST_WKSP_SIZE_U32.
  * @return : largest histogram frequency,
  *           or an error code (notably when histogram's alphabet is larger than *maxSymbolValuePtr) */
-static size_t HIST_count_parallel_wksp(
-                                unsigned* count, unsigned* maxSymbolValuePtr,
+static UNUSED_ATTR
+size_t HIST_count_parallel_wksp(unsigned* count, unsigned* maxSymbolValuePtr,
                                 const void* source, size_t sourceSize,
                                 HIST_checkInput_e check,
                                 U32* const workSpace)
@@ -151,11 +395,17 @@ size_t HIST_countFast_wksp(unsigned* count, unsigned* maxSymbolValuePtr,
                           const void* source, size_t sourceSize,
                           void* workSpace, size_t workSpaceSize)
 {
-    if (sourceSize < 1500) /* heuristic threshold */
+    if (sourceSize < HIST_FAST_THRESHOLD) /* heuristic threshold */
         return HIST_count_simple(count, maxSymbolValuePtr, source, sourceSize);
+#if defined(ZSTD_ARCH_ARM_SVE2)
+    (void)workSpace;
+    (void)workSpaceSize;
+    return HIST_count_sve2(count, maxSymbolValuePtr, source, sourceSize, trustInput);
+#else
     if ((size_t)workSpace & 3) return ERROR(GENERIC);  /* must be aligned on 4-bytes boundaries */
     if (workSpaceSize < HIST_WKSP_SIZE) return ERROR(workSpace_tooSmall);
     return HIST_count_parallel_wksp(count, maxSymbolValuePtr, source, sourceSize, trustInput, (U32*)workSpace);
+#endif
 }
 
 /* HIST_count_wksp() :
@@ -165,10 +415,15 @@ size_t HIST_count_wksp(unsigned* count, unsigned* maxSymbolValuePtr,
                        const void* source, size_t sourceSize,
                        void* workSpace, size_t workSpaceSize)
 {
+#if defined(ZSTD_ARCH_ARM_SVE2)
+    if (*maxSymbolValuePtr < 255)
+        return HIST_count_sve2(count, maxSymbolValuePtr, source, sourceSize, checkMaxSymbolValue);
+#else
     if ((size_t)workSpace & 3) return ERROR(GENERIC);  /* must be aligned on 4-bytes boundaries */
     if (workSpaceSize < HIST_WKSP_SIZE) return ERROR(workSpace_tooSmall);
     if (*maxSymbolValuePtr < 255)
         return HIST_count_parallel_wksp(count, maxSymbolValuePtr, source, sourceSize, checkMaxSymbolValue, (U32*)workSpace);
+#endif
     *maxSymbolValuePtr = 255;
     return HIST_countFast_wksp(count, maxSymbolValuePtr, source, sourceSize, workSpace, workSpaceSize);
 }
index bea2a9ebff1325e47e19d5e93f794f6e87715c20..e526e9532a7d1f7760460e8ff22a68cad2c310b2 100644 (file)
@@ -35,7 +35,11 @@ unsigned HIST_isError(size_t code);  /**< tells if a return value is an error co
 
 /* --- advanced histogram functions --- */
 
+#if defined(__ARM_FEATURE_SVE2)
+#define HIST_WKSP_SIZE_U32 0
+#else
 #define HIST_WKSP_SIZE_U32 1024
+#endif
 #define HIST_WKSP_SIZE    (HIST_WKSP_SIZE_U32 * sizeof(unsigned))
 /** HIST_count_wksp() :
  *  Same as HIST_count(), but using an externally provided scratch buffer.