#include "vecmath_config.h"
+#if !defined(__ARM_FEATURE_SVE_BITS) || __ARM_FEATURE_SVE_BITS == 0
+/* If not specified by -msve-vector-bits, assume maximum vector length. */
+# define SVE_VECTOR_BYTES 256
+#else
+# define SVE_VECTOR_BYTES (__ARM_FEATURE_SVE_BITS / 8)
+#endif
+#define SVE_NUM_FLTS (SVE_VECTOR_BYTES / sizeof (float))
+#define SVE_NUM_DBLS (SVE_VECTOR_BYTES / sizeof (double))
+/* Predicate is stored as one bit per byte of VL so requires VL / 64 bytes. */
+#define SVE_NUM_PG_BYTES (SVE_VECTOR_BYTES / sizeof (uint64_t))
+
#define SV_NAME_F1(fun) _ZGVsMxv_##fun##f
#define SV_NAME_D1(fun) _ZGVsMxv_##fun
#define SV_NAME_F2(fun) _ZGVsMxvv_##fun##f
#define SV_NAME_D2(fun) _ZGVsMxvv_##fun
+static inline void
+svstr_p (uint8_t *dst, svbool_t p)
+{
+ /* Predicate STR does not currently have an intrinsic. */
+ __asm__("str %0, [%x1]\n" : : "Upa"(p), "r"(dst) : "memory");
+}
+
/* Double precision. */
static inline svint64_t
sv_s64 (int64_t x)
static inline svfloat64_t
sv_call_f64 (double (*f) (double), svfloat64_t x, svfloat64_t y, svbool_t cmp)
{
- svbool_t p = svpfirst (cmp, svpfalse ());
- while (svptest_any (cmp, p))
+ double tmp[SVE_NUM_DBLS];
+ uint8_t pg_bits[SVE_NUM_PG_BYTES];
+ svstr_p (pg_bits, cmp);
+ svst1 (svptrue_b64 (), tmp, svsel (cmp, x, y));
+
+ for (int i = 0; i < svcntd (); i++)
{
- double elem = svclastb_n_f64 (p, 0, x);
- elem = (*f) (elem);
- svfloat64_t y2 = svdup_n_f64 (elem);
- y = svsel_f64 (p, y2, y);
- p = svpnext_b64 (cmp, p);
+ if (pg_bits[i] & 1)
+ tmp[i] = f (tmp[i]);
}
- return y;
+ return svld1 (svptrue_b64 (), tmp);
}
static inline svfloat64_t
sv_call2_f64 (double (*f) (double, double), svfloat64_t x1, svfloat64_t x2,
svfloat64_t y, svbool_t cmp)
{
- svbool_t p = svpfirst (cmp, svpfalse ());
- while (svptest_any (cmp, p))
+ double tmp1[SVE_NUM_DBLS], tmp2[SVE_NUM_DBLS];
+ uint8_t pg_bits[SVE_NUM_PG_BYTES];
+ svstr_p (pg_bits, cmp);
+ svst1 (svptrue_b64 (), tmp1, svsel (cmp, x1, y));
+ svst1 (cmp, tmp2, x2);
+
+ for (int i = 0; i < svcntd (); i++)
{
- double elem1 = svclastb_n_f64 (p, 0, x1);
- double elem2 = svclastb_n_f64 (p, 0, x2);
- double ret = (*f) (elem1, elem2);
- svfloat64_t y2 = svdup_n_f64 (ret);
- y = svsel_f64 (p, y2, y);
- p = svpnext_b64 (cmp, p);
+ if (pg_bits[i] & 1)
+ tmp1[i] = f (tmp1[i], tmp2[i]);
}
- return y;
+ return svld1 (svptrue_b64 (), tmp1);
}
static inline svuint64_t
static inline svfloat32_t
sv_call_f32 (float (*f) (float), svfloat32_t x, svfloat32_t y, svbool_t cmp)
{
- svbool_t p = svpfirst (cmp, svpfalse ());
- while (svptest_any (cmp, p))
+ float tmp[SVE_NUM_FLTS];
+ uint8_t pg_bits[SVE_NUM_PG_BYTES];
+ svstr_p (pg_bits, cmp);
+ svst1 (svptrue_b32 (), tmp, svsel (cmp, x, y));
+
+ for (int i = 0; i < svcntd (); i++)
{
- float elem = svclastb_n_f32 (p, 0, x);
- elem = f (elem);
- svfloat32_t y2 = svdup_n_f32 (elem);
- y = svsel_f32 (p, y2, y);
- p = svpnext_b32 (cmp, p);
+ uint8_t p = pg_bits[i];
+ if (p & 1)
+ tmp[i * 2] = f (tmp[i * 2]);
+ if (p & (1 << 4))
+ tmp[i * 2 + 1] = f (tmp[i * 2 + 1]);
}
- return y;
+ return svld1 (svptrue_b32 (), tmp);
}
static inline svfloat32_t
sv_call2_f32 (float (*f) (float, float), svfloat32_t x1, svfloat32_t x2,
svfloat32_t y, svbool_t cmp)
{
- svbool_t p = svpfirst (cmp, svpfalse ());
- while (svptest_any (cmp, p))
+ float tmp1[SVE_NUM_FLTS], tmp2[SVE_NUM_FLTS];
+ uint8_t pg_bits[SVE_NUM_PG_BYTES];
+ svstr_p (pg_bits, cmp);
+ svst1 (svptrue_b32 (), tmp1, svsel (cmp, x1, y));
+ svst1 (cmp, tmp2, x2);
+
+ for (int i = 0; i < svcntd (); i++)
{
- float elem1 = svclastb_n_f32 (p, 0, x1);
- float elem2 = svclastb_n_f32 (p, 0, x2);
- float ret = f (elem1, elem2);
- svfloat32_t y2 = svdup_n_f32 (ret);
- y = svsel_f32 (p, y2, y);
- p = svpnext_b32 (cmp, p);
+ uint8_t p = pg_bits[i];
+ if (p & 1)
+ tmp1[i * 2] = f (tmp1[i * 2], tmp2[i * 2]);
+ if (p & (1 << 4))
+ tmp1[i * 2 + 1] = f (tmp1[i * 2 + 1], tmp2[i * 2 + 1]);
}
- return y;
+ return svld1 (svptrue_b32 (), tmp1);
}
-
#endif