]> git.ipfire.org Git - thirdparty/glibc.git/commitdiff
AArch64: Optimise SVE scalar callbacks
authorJoe Ramsay <Joe.Ramsay@arm.com>
Thu, 6 Nov 2025 15:36:03 +0000 (15:36 +0000)
committerWilco Dijkstra <wilco.dijkstra@arm.com>
Thu, 6 Nov 2025 15:45:37 +0000 (15:45 +0000)
Instead of using SVE instructions to marshall special results into the
correct lane, just write the entire vector (and the predicate) to
memory, then use cheaper scalar operations.

Geomean speedup of 16% in special intervals on Neoverse with GCC 14.

Reviewed-by: Wilco Dijkstra <Wilco.Dijkstra@arm.com>
sysdeps/aarch64/fpu/sv_math.h

index 3d576df4cce2f8429505f1a466d875f0535aeb10..65d7f0ff207bd2a93fbeaf670fcfd5d602d7d677 100644 (file)
 
 #include "vecmath_config.h"
 
+#if !defined(__ARM_FEATURE_SVE_BITS) || __ARM_FEATURE_SVE_BITS == 0
+/* If not specified by -msve-vector-bits, assume maximum vector length.  */
+# define SVE_VECTOR_BYTES 256
+#else
+# define SVE_VECTOR_BYTES (__ARM_FEATURE_SVE_BITS / 8)
+#endif
+#define SVE_NUM_FLTS (SVE_VECTOR_BYTES / sizeof (float))
+#define SVE_NUM_DBLS (SVE_VECTOR_BYTES / sizeof (double))
+/* Predicate is stored as one bit per byte of VL so requires VL / 64 bytes.  */
+#define SVE_NUM_PG_BYTES (SVE_VECTOR_BYTES / sizeof (uint64_t))
+
 #define SV_NAME_F1(fun) _ZGVsMxv_##fun##f
 #define SV_NAME_D1(fun) _ZGVsMxv_##fun
 #define SV_NAME_F2(fun) _ZGVsMxvv_##fun##f
 #define SV_NAME_D2(fun) _ZGVsMxvv_##fun
 
+static inline void
+svstr_p (uint8_t *dst, svbool_t p)
+{
+  /* Predicate STR does not currently have an intrinsic.  */
+  __asm__("str %0, [%x1]\n" : : "Upa"(p), "r"(dst) : "memory");
+}
+
 /* Double precision.  */
 static inline svint64_t
 sv_s64 (int64_t x)
@@ -51,33 +69,35 @@ sv_f64 (double x)
 static inline svfloat64_t
 sv_call_f64 (double (*f) (double), svfloat64_t x, svfloat64_t y, svbool_t cmp)
 {
-  svbool_t p = svpfirst (cmp, svpfalse ());
-  while (svptest_any (cmp, p))
+  double tmp[SVE_NUM_DBLS];
+  uint8_t pg_bits[SVE_NUM_PG_BYTES];
+  svstr_p (pg_bits, cmp);
+  svst1 (svptrue_b64 (), tmp, svsel (cmp, x, y));
+
+  for (int i = 0; i < svcntd (); i++)
     {
-      double elem = svclastb_n_f64 (p, 0, x);
-      elem = (*f) (elem);
-      svfloat64_t y2 = svdup_n_f64 (elem);
-      y = svsel_f64 (p, y2, y);
-      p = svpnext_b64 (cmp, p);
+      if (pg_bits[i] & 1)
+       tmp[i] = f (tmp[i]);
     }
-  return y;
+  return svld1 (svptrue_b64 (), tmp);
 }
 
 static inline svfloat64_t
 sv_call2_f64 (double (*f) (double, double), svfloat64_t x1, svfloat64_t x2,
              svfloat64_t y, svbool_t cmp)
 {
-  svbool_t p = svpfirst (cmp, svpfalse ());
-  while (svptest_any (cmp, p))
+  double tmp1[SVE_NUM_DBLS], tmp2[SVE_NUM_DBLS];
+  uint8_t pg_bits[SVE_NUM_PG_BYTES];
+  svstr_p (pg_bits, cmp);
+  svst1 (svptrue_b64 (), tmp1, svsel (cmp, x1, y));
+  svst1 (cmp, tmp2, x2);
+
+  for (int i = 0; i < svcntd (); i++)
     {
-      double elem1 = svclastb_n_f64 (p, 0, x1);
-      double elem2 = svclastb_n_f64 (p, 0, x2);
-      double ret = (*f) (elem1, elem2);
-      svfloat64_t y2 = svdup_n_f64 (ret);
-      y = svsel_f64 (p, y2, y);
-      p = svpnext_b64 (cmp, p);
+      if (pg_bits[i] & 1)
+       tmp1[i] = f (tmp1[i], tmp2[i]);
     }
-  return y;
+  return svld1 (svptrue_b64 (), tmp1);
 }
 
 static inline svuint64_t
@@ -109,33 +129,40 @@ sv_f32 (float x)
 static inline svfloat32_t
 sv_call_f32 (float (*f) (float), svfloat32_t x, svfloat32_t y, svbool_t cmp)
 {
-  svbool_t p = svpfirst (cmp, svpfalse ());
-  while (svptest_any (cmp, p))
+  float tmp[SVE_NUM_FLTS];
+  uint8_t pg_bits[SVE_NUM_PG_BYTES];
+  svstr_p (pg_bits, cmp);
+  svst1 (svptrue_b32 (), tmp, svsel (cmp, x, y));
+
+  for (int i = 0; i < svcntd (); i++)
     {
-      float elem = svclastb_n_f32 (p, 0, x);
-      elem = f (elem);
-      svfloat32_t y2 = svdup_n_f32 (elem);
-      y = svsel_f32 (p, y2, y);
-      p = svpnext_b32 (cmp, p);
+      uint8_t p = pg_bits[i];
+      if (p & 1)
+       tmp[i * 2] = f (tmp[i * 2]);
+      if (p & (1 << 4))
+       tmp[i * 2 + 1] = f (tmp[i * 2 + 1]);
     }
-  return y;
+  return svld1 (svptrue_b32 (), tmp);
 }
 
 static inline svfloat32_t
 sv_call2_f32 (float (*f) (float, float), svfloat32_t x1, svfloat32_t x2,
              svfloat32_t y, svbool_t cmp)
 {
-  svbool_t p = svpfirst (cmp, svpfalse ());
-  while (svptest_any (cmp, p))
+  float tmp1[SVE_NUM_FLTS], tmp2[SVE_NUM_FLTS];
+  uint8_t pg_bits[SVE_NUM_PG_BYTES];
+  svstr_p (pg_bits, cmp);
+  svst1 (svptrue_b32 (), tmp1, svsel (cmp, x1, y));
+  svst1 (cmp, tmp2, x2);
+
+  for (int i = 0; i < svcntd (); i++)
     {
-      float elem1 = svclastb_n_f32 (p, 0, x1);
-      float elem2 = svclastb_n_f32 (p, 0, x2);
-      float ret = f (elem1, elem2);
-      svfloat32_t y2 = svdup_n_f32 (ret);
-      y = svsel_f32 (p, y2, y);
-      p = svpnext_b32 (cmp, p);
+      uint8_t p = pg_bits[i];
+      if (p & 1)
+       tmp1[i * 2] = f (tmp1[i * 2], tmp2[i * 2]);
+      if (p & (1 << 4))
+       tmp1[i * 2 + 1] = f (tmp1[i * 2 + 1], tmp2[i * 2 + 1]);
     }
-  return y;
+  return svld1 (svptrue_b32 (), tmp1);
 }
-
 #endif