]> git.ipfire.org Git - thirdparty/qemu.git/commitdiff
target/arm: Support FPCR.AH in SME FMOPS, BFMOPS
authorRichard Henderson <richard.henderson@linaro.org>
Fri, 4 Jul 2025 14:21:06 +0000 (08:21 -0600)
committerPeter Maydell <peter.maydell@linaro.org>
Fri, 4 Jul 2025 14:53:23 +0000 (15:53 +0100)
For non-widening, we can use float_muladd_negate_product,
For widening, which uses dot-product, we need to handle
the negation explicitly.

Reviewed-by: Peter Maydell <peter.maydell@linaro.org>
Signed-off-by: Richard Henderson <richard.henderson@linaro.org>
Message-id: 20250704142112.1018902-104-richard.henderson@linaro.org
Signed-off-by: Peter Maydell <peter.maydell@linaro.org>
target/arm/tcg/helper-sme.h
target/arm/tcg/sme_helper.c
target/arm/tcg/translate-sme.c
target/arm/tcg/vec_internal.h

index 16083660e2f94d85a61d07a8bc9af0825c6fdec4..2b22c6aee50b944502100b25adf56c9aefe53806 100644 (file)
@@ -143,6 +143,25 @@ DEF_HELPER_FLAGS_7(sme_fmopa_d, TCG_CALL_NO_RWG,
                    void, ptr, ptr, ptr, ptr, ptr, fpst, i32)
 DEF_HELPER_FLAGS_7(sme_bfmopa_w, TCG_CALL_NO_RWG,
                    void, ptr, ptr, ptr, ptr, ptr, env, i32)
+
+DEF_HELPER_FLAGS_7(sme_fmops_w_h, TCG_CALL_NO_RWG,
+                   void, ptr, ptr, ptr, ptr, ptr, env, i32)
+DEF_HELPER_FLAGS_7(sme_fmops_s, TCG_CALL_NO_RWG,
+                   void, ptr, ptr, ptr, ptr, ptr, fpst, i32)
+DEF_HELPER_FLAGS_7(sme_fmops_d, TCG_CALL_NO_RWG,
+                   void, ptr, ptr, ptr, ptr, ptr, fpst, i32)
+DEF_HELPER_FLAGS_7(sme_bfmops_w, TCG_CALL_NO_RWG,
+                   void, ptr, ptr, ptr, ptr, ptr, env, i32)
+
+DEF_HELPER_FLAGS_7(sme_ah_fmops_w_h, TCG_CALL_NO_RWG,
+                   void, ptr, ptr, ptr, ptr, ptr, env, i32)
+DEF_HELPER_FLAGS_7(sme_ah_fmops_s, TCG_CALL_NO_RWG,
+                   void, ptr, ptr, ptr, ptr, ptr, fpst, i32)
+DEF_HELPER_FLAGS_7(sme_ah_fmops_d, TCG_CALL_NO_RWG,
+                   void, ptr, ptr, ptr, ptr, ptr, fpst, i32)
+DEF_HELPER_FLAGS_7(sme_ah_bfmops_w, TCG_CALL_NO_RWG,
+                   void, ptr, ptr, ptr, ptr, ptr, env, i32)
+
 DEF_HELPER_FLAGS_6(sme_smopa_s, TCG_CALL_NO_RWG,
                    void, ptr, ptr, ptr, ptr, ptr, i32)
 DEF_HELPER_FLAGS_6(sme_umopa_s, TCG_CALL_NO_RWG,
index 4772c97debba46be0dbfbca1ecfda57b006e1a9d..eff0ce74808e2c5645db0aca748d481863cd7471 100644 (file)
@@ -1002,19 +1002,18 @@ void HELPER(sme_addva_d)(void *vzda, void *vzn, void *vpn,
     }
 }
 
-void HELPER(sme_fmopa_s)(void *vza, void *vzn, void *vzm, void *vpn,
-                         void *vpm, float_status *fpst, uint32_t desc)
+static void do_fmopa_s(void *vza, void *vzn, void *vzm, uint16_t *pn,
+                       uint16_t *pm, float_status *fpst, uint32_t desc,
+                       uint32_t negx, int negf)
 {
     intptr_t row, col, oprsz = simd_maxsz(desc);
-    uint32_t neg = simd_data(desc) << 31;
-    uint16_t *pn = vpn, *pm = vpm;
 
     for (row = 0; row < oprsz; ) {
         uint16_t pa = pn[H2(row >> 4)];
         do {
             if (pa & 1) {
                 void *vza_row = vza + tile_vslice_offset(row);
-                uint32_t n = *(uint32_t *)(vzn + H1_4(row)) ^ neg;
+                uint32_t n = *(uint32_t *)(vzn + H1_4(row)) ^ negx;
 
                 for (col = 0; col < oprsz; ) {
                     uint16_t pb = pm[H2(col >> 4)];
@@ -1022,7 +1021,7 @@ void HELPER(sme_fmopa_s)(void *vza, void *vzn, void *vzm, void *vpn,
                         if (pb & 1) {
                             uint32_t *a = vza_row + H1_4(col);
                             uint32_t *m = vzm + H1_4(col);
-                            *a = float32_muladd(n, *m, *a, 0, fpst);
+                            *a = float32_muladd(n, *m, *a, negf, fpst);
                         }
                         col += 4;
                         pb >>= 4;
@@ -1035,29 +1034,65 @@ void HELPER(sme_fmopa_s)(void *vza, void *vzn, void *vzm, void *vpn,
     }
 }
 
-void HELPER(sme_fmopa_d)(void *vza, void *vzn, void *vzm, void *vpn,
+void HELPER(sme_fmopa_s)(void *vza, void *vzn, void *vzm, void *vpn,
+                         void *vpm, float_status *fpst, uint32_t desc)
+{
+    do_fmopa_s(vza, vzn, vzm, vpn, vpm, fpst, desc, 0, 0);
+}
+
+void HELPER(sme_fmops_s)(void *vza, void *vzn, void *vzm, void *vpn,
                          void *vpm, float_status *fpst, uint32_t desc)
+{
+    do_fmopa_s(vza, vzn, vzm, vpn, vpm, fpst, desc, 1u << 31, 0);
+}
+
+void HELPER(sme_ah_fmops_s)(void *vza, void *vzn, void *vzm, void *vpn,
+                            void *vpm, float_status *fpst, uint32_t desc)
+{
+    do_fmopa_s(vza, vzn, vzm, vpn, vpm, fpst, desc, 0,
+               float_muladd_negate_product);
+}
+
+static void do_fmopa_d(uint64_t *za, uint64_t *zn, uint64_t *zm, uint8_t *pn,
+                       uint8_t *pm, float_status *fpst, uint32_t desc,
+                       uint64_t negx, int negf)
 {
     intptr_t row, col, oprsz = simd_oprsz(desc) / 8;
-    uint64_t neg = (uint64_t)simd_data(desc) << 63;
-    uint64_t *za = vza, *zn = vzn, *zm = vzm;
-    uint8_t *pn = vpn, *pm = vpm;
 
     for (row = 0; row < oprsz; ++row) {
         if (pn[H1(row)] & 1) {
             uint64_t *za_row = &za[tile_vslice_index(row)];
-            uint64_t n = zn[row] ^ neg;
+            uint64_t n = zn[row] ^ negx;
 
             for (col = 0; col < oprsz; ++col) {
                 if (pm[H1(col)] & 1) {
                     uint64_t *a = &za_row[col];
-                    *a = float64_muladd(n, zm[col], *a, 0, fpst);
+                    *a = float64_muladd(n, zm[col], *a, negf, fpst);
                 }
             }
         }
     }
 }
 
+void HELPER(sme_fmopa_d)(void *vza, void *vzn, void *vzm, void *vpn,
+                         void *vpm, float_status *fpst, uint32_t desc)
+{
+    do_fmopa_d(vza, vzn, vzm, vpn, vpm, fpst, desc, 0, 0);
+}
+
+void HELPER(sme_fmops_d)(void *vza, void *vzn, void *vzm, void *vpn,
+                         void *vpm, float_status *fpst, uint32_t desc)
+{
+    do_fmopa_d(vza, vzn, vzm, vpn, vpm, fpst, desc, 1ull << 63, 0);
+}
+
+void HELPER(sme_ah_fmops_d)(void *vza, void *vzn, void *vzm, void *vpn,
+                            void *vpm, float_status *fpst, uint32_t desc)
+{
+    do_fmopa_d(vza, vzn, vzm, vpn, vpm, fpst, desc, 0,
+               float_muladd_negate_product);
+}
+
 /*
  * Alter PAIR as needed for controlling predicates being false,
  * and for NEG on an enabled row element.
@@ -1078,6 +1113,20 @@ static inline uint32_t f16mop_adj_pair(uint32_t pair, uint32_t pg, uint32_t neg)
     return pair;
 }
 
+static inline uint32_t f16mop_ah_neg_adj_pair(uint32_t pair, uint32_t pg)
+{
+    uint32_t l = pg & 1 ? float16_ah_chs(pair) : 0;
+    uint32_t h = pg & 4 ? float16_ah_chs(pair >> 16) : 0;
+    return l | (h << 16);
+}
+
+static inline uint32_t bf16mop_ah_neg_adj_pair(uint32_t pair, uint32_t pg)
+{
+    uint32_t l = pg & 1 ? bfloat16_ah_chs(pair) : 0;
+    uint32_t h = pg & 4 ? bfloat16_ah_chs(pair >> 16) : 0;
+    return l | (h << 16);
+}
+
 static float32 f16_dotadd(float32 sum, uint32_t e1, uint32_t e2,
                           float_status *s_f16, float_status *s_std,
                           float_status *s_odd)
@@ -1146,12 +1195,11 @@ static float32 f16_dotadd(float32 sum, uint32_t e1, uint32_t e2,
     return float32_add(sum, t32, s_std);
 }
 
-void HELPER(sme_fmopa_w_h)(void *vza, void *vzn, void *vzm, void *vpn,
-                           void *vpm, CPUARMState *env, uint32_t desc)
+static void do_fmopa_w_h(void *vza, void *vzn, void *vzm, uint16_t *pn,
+                         uint16_t *pm, CPUARMState *env, uint32_t desc,
+                         uint32_t negx, bool ah_neg)
 {
     intptr_t row, col, oprsz = simd_maxsz(desc);
-    uint32_t neg = simd_data(desc) * 0x80008000u;
-    uint16_t *pn = vpn, *pm = vpm;
     float_status fpst_odd = env->vfp.fp_status[FPST_ZA];
 
     set_float_rounding_mode(float_round_to_odd, &fpst_odd);
@@ -1162,7 +1210,11 @@ void HELPER(sme_fmopa_w_h)(void *vza, void *vzn, void *vzm, void *vpn,
             void *vza_row = vza + tile_vslice_offset(row);
             uint32_t n = *(uint32_t *)(vzn + H1_4(row));
 
-            n = f16mop_adj_pair(n, prow, neg);
+            if (ah_neg) {
+                n = f16mop_ah_neg_adj_pair(n, prow);
+            } else {
+                n = f16mop_adj_pair(n, prow, negx);
+            }
 
             for (col = 0; col < oprsz; ) {
                 uint16_t pcol = pm[H2(col >> 4)];
@@ -1187,6 +1239,24 @@ void HELPER(sme_fmopa_w_h)(void *vza, void *vzn, void *vzm, void *vpn,
     }
 }
 
+void HELPER(sme_fmopa_w_h)(void *vza, void *vzn, void *vzm, void *vpn,
+                           void *vpm, CPUARMState *env, uint32_t desc)
+{
+    do_fmopa_w_h(vza, vzn, vzm, vpn, vpm, env, desc, 0, false);
+}
+
+void HELPER(sme_fmops_w_h)(void *vza, void *vzn, void *vzm, void *vpn,
+                           void *vpm, CPUARMState *env, uint32_t desc)
+{
+    do_fmopa_w_h(vza, vzn, vzm, vpn, vpm, env, desc, 0x80008000u, false);
+}
+
+void HELPER(sme_ah_fmops_w_h)(void *vza, void *vzn, void *vzm, void *vpn,
+                              void *vpm, CPUARMState *env, uint32_t desc)
+{
+    do_fmopa_w_h(vza, vzn, vzm, vpn, vpm, env, desc, 0, true);
+}
+
 void HELPER(sme2_fdot_h)(void *vd, void *vn, void *vm, void *va,
                          CPUARMState *env, uint32_t desc)
 {
@@ -1261,12 +1331,11 @@ void HELPER(sme2_fvdot_idx_h)(void *vd, void *vn, void *vm, void *va,
     }
 }
 
-void HELPER(sme_bfmopa_w)(void *vza, void *vzn, void *vzm,
-                          void *vpn, void *vpm, CPUARMState *env, uint32_t desc)
+static void do_bfmopa_w(void *vza, void *vzn, void *vzm,
+                        uint16_t *pn, uint16_t *pm, CPUARMState *env,
+                        uint32_t desc, uint32_t negx, bool ah_neg)
 {
     intptr_t row, col, oprsz = simd_maxsz(desc);
-    uint32_t neg = simd_data(desc) * 0x80008000u;
-    uint16_t *pn = vpn, *pm = vpm;
     float_status fpst, fpst_odd;
 
     if (is_ebf(env, &fpst, &fpst_odd)) {
@@ -1276,7 +1345,11 @@ void HELPER(sme_bfmopa_w)(void *vza, void *vzn, void *vzm,
                 void *vza_row = vza + tile_vslice_offset(row);
                 uint32_t n = *(uint32_t *)(vzn + H1_4(row));
 
-                n = f16mop_adj_pair(n, prow, neg);
+                if (ah_neg) {
+                    n = bf16mop_ah_neg_adj_pair(n, prow);
+                } else {
+                    n = f16mop_adj_pair(n, prow, negx);
+                }
 
                 for (col = 0; col < oprsz; ) {
                     uint16_t pcol = pm[H2(col >> 4)];
@@ -1303,7 +1376,11 @@ void HELPER(sme_bfmopa_w)(void *vza, void *vzn, void *vzm,
                 void *vza_row = vza + tile_vslice_offset(row);
                 uint32_t n = *(uint32_t *)(vzn + H1_4(row));
 
-                n = f16mop_adj_pair(n, prow, neg);
+                if (ah_neg) {
+                    n = bf16mop_ah_neg_adj_pair(n, prow);
+                } else {
+                    n = f16mop_adj_pair(n, prow, negx);
+                }
 
                 for (col = 0; col < oprsz; ) {
                     uint16_t pcol = pm[H2(col >> 4)];
@@ -1326,6 +1403,24 @@ void HELPER(sme_bfmopa_w)(void *vza, void *vzn, void *vzm,
     }
 }
 
+void HELPER(sme_bfmopa_w)(void *vza, void *vzn, void *vzm, void *vpn,
+                          void *vpm, CPUARMState *env, uint32_t desc)
+{
+    do_bfmopa_w(vza, vzn, vzm, vpn, vpm, env, desc, 0, false);
+}
+
+void HELPER(sme_bfmops_w)(void *vza, void *vzn, void *vzm, void *vpn,
+                          void *vpm, CPUARMState *env, uint32_t desc)
+{
+    do_bfmopa_w(vza, vzn, vzm, vpn, vpm, env, desc, 0x80008000u, false);
+}
+
+void HELPER(sme_ah_bfmops_w)(void *vza, void *vzn, void *vzm, void *vpn,
+                             void *vpm, CPUARMState *env, uint32_t desc)
+{
+    do_bfmopa_w(vza, vzn, vzm, vpn, vpm, env, desc, 0, true);
+}
+
 typedef uint32_t IMOPFn32(uint32_t, uint32_t, uint32_t, uint8_t, bool);
 static inline void do_imopa_s(uint32_t *za, uint32_t *zn, uint32_t *zm,
                               uint8_t *pn, uint8_t *pm,
index 38d0231b0a92687c56c8842152f82fd267f9f1c9..782f4080611538657a800d65de17c48415674380 100644 (file)
@@ -526,7 +526,7 @@ static bool do_outprod_fpst(DisasContext *s, arg_op *a, MemOp esz,
                             gen_helper_gvec_5_ptr *fn)
 {
     int svl = streaming_vec_reg_size(s);
-    uint32_t desc = simd_desc(svl, svl, a->sub);
+    uint32_t desc = simd_desc(svl, svl, 0);
     TCGv_ptr za, zn, zm, pn, pm, fpst;
 
     if (!sme_smza_enabled_check(s)) {
@@ -548,7 +548,7 @@ static bool do_outprod_env(DisasContext *s, arg_op *a, MemOp esz,
                            gen_helper_gvec_5_ptr *fn)
 {
     int svl = streaming_vec_reg_size(s);
-    uint32_t desc = simd_desc(svl, svl, a->sub);
+    uint32_t desc = simd_desc(svl, svl, 0);
     TCGv_ptr za, zn, zm, pn, pm;
 
     if (!sme_smza_enabled_check(s)) {
@@ -565,14 +565,23 @@ static bool do_outprod_env(DisasContext *s, arg_op *a, MemOp esz,
     return true;
 }
 
-TRANS_FEAT(FMOPA_w_h, aa64_sme, do_outprod_env, a,
-           MO_32, gen_helper_sme_fmopa_w_h)
-TRANS_FEAT(FMOPA_s, aa64_sme, do_outprod_fpst, a,
-           MO_32, FPST_ZA, gen_helper_sme_fmopa_s)
-TRANS_FEAT(FMOPA_d, aa64_sme_f64f64, do_outprod_fpst, a,
-           MO_64, FPST_ZA, gen_helper_sme_fmopa_d)
-
-TRANS_FEAT(BFMOPA_w, aa64_sme, do_outprod_env, a, MO_32, gen_helper_sme_bfmopa_w)
+TRANS_FEAT(FMOPA_w_h, aa64_sme, do_outprod_env, a, MO_32,
+           !a->sub ? gen_helper_sme_fmopa_w_h
+           : !s->fpcr_ah ? gen_helper_sme_fmops_w_h
+           : gen_helper_sme_ah_fmops_w_h)
+TRANS_FEAT(FMOPA_s, aa64_sme, do_outprod_fpst, a, MO_32, FPST_ZA,
+           !a->sub ? gen_helper_sme_fmopa_s
+           : !s->fpcr_ah ? gen_helper_sme_fmops_s
+           : gen_helper_sme_ah_fmops_s)
+TRANS_FEAT(FMOPA_d, aa64_sme_f64f64, do_outprod_fpst, a, MO_64, FPST_ZA,
+           !a->sub ? gen_helper_sme_fmopa_d
+           : !s->fpcr_ah ? gen_helper_sme_fmops_d
+           : gen_helper_sme_ah_fmops_d)
+
+TRANS_FEAT(BFMOPA_w, aa64_sme, do_outprod_env, a, MO_32,
+           !a->sub ? gen_helper_sme_bfmopa_w
+           : !s->fpcr_ah ? gen_helper_sme_bfmops_w
+           : gen_helper_sme_ah_bfmops_w)
 
 TRANS_FEAT(SMOPA_s, aa64_sme, do_outprod, a, MO_32, gen_helper_sme_smopa_s)
 TRANS_FEAT(UMOPA_s, aa64_sme, do_outprod, a, MO_32, gen_helper_sme_umopa_s)
index 957bf6d9fcaf24417d31523cbb909240d96672ce..cf41b03dbcd52ec33ea3d467e648ad1643135a48 100644 (file)
@@ -300,6 +300,11 @@ bool is_ebf(CPUARMState *env, float_status *statusp, float_status *oddstatusp);
 /*
  * Negate as for FPCR.AH=1 -- do not negate NaNs.
  */
+static inline float16 bfloat16_ah_chs(float16 a)
+{
+    return bfloat16_is_any_nan(a) ? a : bfloat16_chs(a);
+}
+
 static inline float16 float16_ah_chs(float16 a)
 {
     return float16_is_any_nan(a) ? a : float16_chs(a);