From dbccbd17a94fb43f7938846e3e225c0c2480bb93 Mon Sep 17 00:00:00 2001 From: Icenowy Zheng Date: Sun, 15 Dec 2024 01:31:48 +0800 Subject: [PATCH] adler32_rvv: Fix some overflow problems There are currently some overflow problems in adler32_rvv implementation, which can lead to wrong results for some input, and these problems could be easily exhibited when running `git fsck` with zlib-ng suitituting the system zlib on a big git repository. These problems and the solutions are the following: - When the input data is long enough, the v_buf32_accu can overflow too. Add it to the modulo code that happens per ~NMAX bytes. - When the vector data is reduced to scalar ones, the resulting scalar value (and the proceeded length) may lead to the calculation of sum2 to overflow. Add mod BASE to all these reductions and initial calculation of sum2. - When the remaining data less than vl bytes, the code falls back to a scalar implementation; however the sum2 and alder2 values are just reduced from vectors and could be very big that makes sum2 overflows in the scalar code. Modulo them before the scalar code to prevent such overflow (because vl is surely quite smaller than NMAX). Signed-off-by: Icenowy Zheng --- arch/riscv/adler32_rvv.c | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/arch/riscv/adler32_rvv.c b/arch/riscv/adler32_rvv.c index d0f9aaa5..d822d75a 100644 --- a/arch/riscv/adler32_rvv.c +++ b/arch/riscv/adler32_rvv.c @@ -72,6 +72,7 @@ static inline uint32_t adler32_rvv_impl(uint32_t adler, uint8_t* restrict dst, c /* do modulo once each block of NMAX size */ if (++cnt >= nmax_limit) { v_adler32_prev_accu = __riscv_vremu_vx_u32m4(v_adler32_prev_accu, BASE, vl); + v_buf32_accu = __riscv_vremu_vx_u32m4(v_buf32_accu, BASE, vl); cnt = 0; } } @@ -99,16 +100,19 @@ static inline uint32_t adler32_rvv_impl(uint32_t adler, uint8_t* restrict dst, c vuint32m1_t v_sum2_sum = __riscv_vmv_s_x_u32m1(0, vl); v_sum2_sum = __riscv_vredsum_vs_u32m4_u32m1(v_sum32_accu, v_sum2_sum, vl); - uint32_t sum2_sum = __riscv_vmv_x_s_u32m1_u32(v_sum2_sum); + uint32_t sum2_sum = __riscv_vmv_x_s_u32m1_u32(v_sum2_sum) % BASE; - sum2 += (sum2_sum + adler * (len - left)); + sum2 += (sum2_sum + adler * ((len - left) % BASE)); vuint32m1_t v_adler_sum = __riscv_vmv_s_x_u32m1(0, vl); v_adler_sum = __riscv_vredsum_vs_u32m4_u32m1(v_buf32_accu, v_adler_sum, vl); - uint32_t adler_sum = __riscv_vmv_x_s_u32m1_u32(v_adler_sum); + uint32_t adler_sum = __riscv_vmv_x_s_u32m1_u32(v_adler_sum) % BASE; adler += adler_sum; + sum2 %= BASE; + adler %= BASE; + while (left--) { if (COPY) *dst++ = *src; adler += *src++; -- 2.47.2