From 5b82fb18827e962af9f080fdf3c1a69802783f67 Mon Sep 17 00:00:00 2001 From: Joe Ramsay Date: Thu, 6 Nov 2025 15:36:03 +0000 Subject: [PATCH] AArch64: Optimise SVE scalar callbacks Instead of using SVE instructions to marshall special results into the correct lane, just write the entire vector (and the predicate) to memory, then use cheaper scalar operations. Geomean speedup of 16% in special intervals on Neoverse with GCC 14. Reviewed-by: Wilco Dijkstra --- sysdeps/aarch64/fpu/sv_math.h | 97 ++++++++++++++++++++++------------- 1 file changed, 62 insertions(+), 35 deletions(-) diff --git a/sysdeps/aarch64/fpu/sv_math.h b/sysdeps/aarch64/fpu/sv_math.h index 3d576df4cc..65d7f0ff20 100644 --- a/sysdeps/aarch64/fpu/sv_math.h +++ b/sysdeps/aarch64/fpu/sv_math.h @@ -24,11 +24,29 @@ #include "vecmath_config.h" +#if !defined(__ARM_FEATURE_SVE_BITS) || __ARM_FEATURE_SVE_BITS == 0 +/* If not specified by -msve-vector-bits, assume maximum vector length. */ +# define SVE_VECTOR_BYTES 256 +#else +# define SVE_VECTOR_BYTES (__ARM_FEATURE_SVE_BITS / 8) +#endif +#define SVE_NUM_FLTS (SVE_VECTOR_BYTES / sizeof (float)) +#define SVE_NUM_DBLS (SVE_VECTOR_BYTES / sizeof (double)) +/* Predicate is stored as one bit per byte of VL so requires VL / 64 bytes. */ +#define SVE_NUM_PG_BYTES (SVE_VECTOR_BYTES / sizeof (uint64_t)) + #define SV_NAME_F1(fun) _ZGVsMxv_##fun##f #define SV_NAME_D1(fun) _ZGVsMxv_##fun #define SV_NAME_F2(fun) _ZGVsMxvv_##fun##f #define SV_NAME_D2(fun) _ZGVsMxvv_##fun +static inline void +svstr_p (uint8_t *dst, svbool_t p) +{ + /* Predicate STR does not currently have an intrinsic. */ + __asm__("str %0, [%x1]\n" : : "Upa"(p), "r"(dst) : "memory"); +} + /* Double precision. */ static inline svint64_t sv_s64 (int64_t x) @@ -51,33 +69,35 @@ sv_f64 (double x) static inline svfloat64_t sv_call_f64 (double (*f) (double), svfloat64_t x, svfloat64_t y, svbool_t cmp) { - svbool_t p = svpfirst (cmp, svpfalse ()); - while (svptest_any (cmp, p)) + double tmp[SVE_NUM_DBLS]; + uint8_t pg_bits[SVE_NUM_PG_BYTES]; + svstr_p (pg_bits, cmp); + svst1 (svptrue_b64 (), tmp, svsel (cmp, x, y)); + + for (int i = 0; i < svcntd (); i++) { - double elem = svclastb_n_f64 (p, 0, x); - elem = (*f) (elem); - svfloat64_t y2 = svdup_n_f64 (elem); - y = svsel_f64 (p, y2, y); - p = svpnext_b64 (cmp, p); + if (pg_bits[i] & 1) + tmp[i] = f (tmp[i]); } - return y; + return svld1 (svptrue_b64 (), tmp); } static inline svfloat64_t sv_call2_f64 (double (*f) (double, double), svfloat64_t x1, svfloat64_t x2, svfloat64_t y, svbool_t cmp) { - svbool_t p = svpfirst (cmp, svpfalse ()); - while (svptest_any (cmp, p)) + double tmp1[SVE_NUM_DBLS], tmp2[SVE_NUM_DBLS]; + uint8_t pg_bits[SVE_NUM_PG_BYTES]; + svstr_p (pg_bits, cmp); + svst1 (svptrue_b64 (), tmp1, svsel (cmp, x1, y)); + svst1 (cmp, tmp2, x2); + + for (int i = 0; i < svcntd (); i++) { - double elem1 = svclastb_n_f64 (p, 0, x1); - double elem2 = svclastb_n_f64 (p, 0, x2); - double ret = (*f) (elem1, elem2); - svfloat64_t y2 = svdup_n_f64 (ret); - y = svsel_f64 (p, y2, y); - p = svpnext_b64 (cmp, p); + if (pg_bits[i] & 1) + tmp1[i] = f (tmp1[i], tmp2[i]); } - return y; + return svld1 (svptrue_b64 (), tmp1); } static inline svuint64_t @@ -109,33 +129,40 @@ sv_f32 (float x) static inline svfloat32_t sv_call_f32 (float (*f) (float), svfloat32_t x, svfloat32_t y, svbool_t cmp) { - svbool_t p = svpfirst (cmp, svpfalse ()); - while (svptest_any (cmp, p)) + float tmp[SVE_NUM_FLTS]; + uint8_t pg_bits[SVE_NUM_PG_BYTES]; + svstr_p (pg_bits, cmp); + svst1 (svptrue_b32 (), tmp, svsel (cmp, x, y)); + + for (int i = 0; i < svcntd (); i++) { - float elem = svclastb_n_f32 (p, 0, x); - elem = f (elem); - svfloat32_t y2 = svdup_n_f32 (elem); - y = svsel_f32 (p, y2, y); - p = svpnext_b32 (cmp, p); + uint8_t p = pg_bits[i]; + if (p & 1) + tmp[i * 2] = f (tmp[i * 2]); + if (p & (1 << 4)) + tmp[i * 2 + 1] = f (tmp[i * 2 + 1]); } - return y; + return svld1 (svptrue_b32 (), tmp); } static inline svfloat32_t sv_call2_f32 (float (*f) (float, float), svfloat32_t x1, svfloat32_t x2, svfloat32_t y, svbool_t cmp) { - svbool_t p = svpfirst (cmp, svpfalse ()); - while (svptest_any (cmp, p)) + float tmp1[SVE_NUM_FLTS], tmp2[SVE_NUM_FLTS]; + uint8_t pg_bits[SVE_NUM_PG_BYTES]; + svstr_p (pg_bits, cmp); + svst1 (svptrue_b32 (), tmp1, svsel (cmp, x1, y)); + svst1 (cmp, tmp2, x2); + + for (int i = 0; i < svcntd (); i++) { - float elem1 = svclastb_n_f32 (p, 0, x1); - float elem2 = svclastb_n_f32 (p, 0, x2); - float ret = f (elem1, elem2); - svfloat32_t y2 = svdup_n_f32 (ret); - y = svsel_f32 (p, y2, y); - p = svpnext_b32 (cmp, p); + uint8_t p = pg_bits[i]; + if (p & 1) + tmp1[i * 2] = f (tmp1[i * 2], tmp2[i * 2]); + if (p & (1 << 4)) + tmp1[i * 2 + 1] = f (tmp1[i * 2 + 1], tmp2[i * 2 + 1]); } - return y; + return svld1 (svptrue_b32 (), tmp1); } - #endif -- 2.47.3