From f4f72f9b6bcdec8b2ba20a58a241c8d9d631480c Mon Sep 17 00:00:00 2001 From: YunQiang Su Date: Mon, 26 Aug 2024 08:45:36 +0800 Subject: [PATCH] MIPS: Support vector reduc for MSA MSA provides SHF.fmt and HADD_S/U.fmt, which can be used to implement vector reduction. For min/max of U8/S8, we can use: SHF.B W1, W0, 0xb1 # swap the bytes within every halfword MIN.B W1, W1, W0 SHF.H W2, W1, 0xb1 # swap the halfwords within every word MIN.B W2, W2, W1 SHF.W W3, W2, 0xb1 # swap the words within every doubleword MIN.B W3, W3, W2 SHF.W W4, W3, 0x4e # swap the two doublewords MIN.B W4, W4, W3 For plus of S8/U8, we can use HADD: HADD.H W0, W0, W0 HADD.W W0, W0, W0 HADD.D W0, W0, W0 SHF.W W1, W0, 0x4e # swap the two doublewords ADDV.D W1, W1, W0 COPY_S.B T0, W1 # COPY_U.B for U8 We can do similarly for S16/U16/S32/U32/S64/U64/FLOAT/DOUBLE. gcc: * config/mips/mips-msa.md (MSA_NO_HADD): New mode iterator; HADD is available for S8/U8/S16/U16/S32/U32 only. (reduc_smin_scal_<mode>): New define pattern. (reduc_smax_scal_<mode>): Ditto. (reduc_umin_scal_<mode>): Ditto. (reduc_umax_scal_<mode>): Ditto. (reduc_plus_scal_<mode>): Ditto. (reduc_plus_scal_v4si): Ditto. (reduc_plus_scal_v8hi): Ditto. (reduc_plus_scal_v16qi): Ditto. (reduc_<optab>_scal_<mode>): Ditto. * config/mips/mips-protos.h: New function mips_expand_msa_reduc. * config/mips/mips.cc: New function mips_expand_msa_reduc. * config/mips/mips.md: Define any_bitwise iterator. gcc/testsuite: * gcc.target/mips/msa-reduc.c: New tests. --- gcc/config/mips/mips-msa.md | 128 ++++++++++++++++++++++ gcc/config/mips/mips-protos.h | 1 + gcc/config/mips/mips.cc | 41 +++++++ gcc/config/mips/mips.md | 4 + gcc/testsuite/gcc.target/mips/msa-reduc.c | 119 ++++++++++++++++++++ 5 files changed, 293 insertions(+) create mode 100644 gcc/testsuite/gcc.target/mips/msa-reduc.c diff --git a/gcc/config/mips/mips-msa.md b/gcc/config/mips/mips-msa.md index 377c63f0d357..976f296402ee 100644 --- a/gcc/config/mips/mips-msa.md +++ b/gcc/config/mips/mips-msa.md @@ -125,6 +125,9 @@ ;; Only floating-point modes. 
(define_mode_iterator FMSA [V2DF V4SF]) +;; Only used for reduce_plus_scal: V4SI, V8HI, V16QI have HADD. +(define_mode_iterator MSA_NO_HADD [V2DF V4SF V2DI]) + ;; The attribute gives the integer vector mode with same size. (define_mode_attr VIMODE [(V2DF "V2DI") @@ -2802,3 +2805,128 @@ (set_attr "mode" "TI") (set_attr "compact_form" "never") (set_attr "branch_likely" "no")]) + + +;; Vector reduction operation +(define_expand "reduc_smin_scal_" + [(match_operand: 0 "register_operand") + (match_operand:MSA 1 "register_operand")] + "ISA_HAS_MSA" +{ + rtx tmp = gen_reg_rtx (mode); + mips_expand_msa_reduc (gen_smin3, tmp, operands[1]); + emit_insn (gen_vec_extract (operands[0], tmp, + const0_rtx)); + DONE; +}) + +(define_expand "reduc_smax_scal_" + [(match_operand: 0 "register_operand") + (match_operand:MSA 1 "register_operand")] + "ISA_HAS_MSA" +{ + rtx tmp = gen_reg_rtx (mode); + mips_expand_msa_reduc (gen_smax3, tmp, operands[1]); + emit_insn (gen_vec_extract (operands[0], tmp, + const0_rtx)); + DONE; +}) + +(define_expand "reduc_umin_scal_" + [(match_operand: 0 "register_operand") + (match_operand:IMSA 1 "register_operand")] + "ISA_HAS_MSA" +{ + rtx tmp = gen_reg_rtx (mode); + mips_expand_msa_reduc (gen_umin3, tmp, operands[1]); + emit_insn (gen_vec_extract (operands[0], tmp, + const0_rtx)); + DONE; +}) + +(define_expand "reduc_umax_scal_" + [(match_operand: 0 "register_operand") + (match_operand:IMSA 1 "register_operand")] + "ISA_HAS_MSA" +{ + rtx tmp = gen_reg_rtx (mode); + mips_expand_msa_reduc (gen_umax3, tmp, operands[1]); + emit_insn (gen_vec_extract (operands[0], tmp, + const0_rtx)); + DONE; +}) + +(define_expand "reduc_plus_scal_" + [(match_operand: 0 "register_operand") + (match_operand:MSA_NO_HADD 1 "register_operand")] + "ISA_HAS_MSA" +{ + rtx tmp = gen_reg_rtx (mode); + mips_expand_msa_reduc (gen_add3, tmp, operands[1]); + emit_insn (gen_vec_extract (operands[0], tmp, + const0_rtx)); + DONE; +}) + +(define_expand "reduc_plus_scal_v4si" + 
[(match_operand:SI 0 "register_operand") + (match_operand:V4SI 1 "register_operand")] + "ISA_HAS_MSA" +{ + rtx tmp = gen_reg_rtx (SImode); + rtx tmp1 = gen_reg_rtx (V2DImode); + emit_insn (gen_msa_hadd_s_d (tmp1, operands[1], operands[1])); + emit_insn (gen_vec_extractv4sisi (operands[0], gen_lowpart (V4SImode, tmp1), + const0_rtx)); + emit_insn (gen_vec_extractv4sisi (tmp, gen_lowpart (V4SImode, tmp1), + GEN_INT (2))); + emit_insn (gen_addsi3 (operands[0], operands[0], tmp)); + DONE; +}) + +(define_expand "reduc_plus_scal_v8hi" + [(match_operand:HI 0 "register_operand") + (match_operand:V8HI 1 "register_operand")] + "ISA_HAS_MSA" +{ + rtx tmp1 = gen_reg_rtx (V4SImode); + rtx tmp2 = gen_reg_rtx (V2DImode); + rtx tmp3 = gen_reg_rtx (V2DImode); + emit_insn (gen_msa_hadd_s_w (tmp1, operands[1], operands[1])); + emit_insn (gen_msa_hadd_s_d (tmp2, tmp1, tmp1)); + mips_expand_msa_reduc (gen_addv2di3, tmp3, tmp2); + emit_insn (gen_vec_extractv8hihi (operands[0], gen_lowpart (V8HImode, tmp3), + const0_rtx)); + DONE; +}) + +(define_expand "reduc_plus_scal_v16qi" + [(match_operand:QI 0 "register_operand") + (match_operand:V16QI 1 "register_operand")] + "ISA_HAS_MSA" +{ + rtx tmp1 = gen_reg_rtx (V8HImode); + rtx tmp2 = gen_reg_rtx (V4SImode); + rtx tmp3 = gen_reg_rtx (V2DImode); + rtx tmp4 = gen_reg_rtx (V2DImode); + emit_insn (gen_msa_hadd_s_h (tmp1, operands[1], operands[1])); + emit_insn (gen_msa_hadd_s_w (tmp2, tmp1, tmp1)); + emit_insn (gen_msa_hadd_s_d (tmp3, tmp2, tmp2)); + mips_expand_msa_reduc (gen_addv2di3, tmp4, tmp3); + emit_insn (gen_vec_extractv16qiqi (operands[0], gen_lowpart (V16QImode, tmp4), + const0_rtx)); + DONE; +}) + +(define_expand "reduc__scal_" + [(any_bitwise: + (match_operand: 0 "register_operand") + (match_operand:IMSA 1 "register_operand"))] + "ISA_HAS_MSA" +{ + rtx tmp = gen_reg_rtx (mode); + mips_expand_msa_reduc (gen_3, tmp, operands[1]); + emit_insn (gen_vec_extract (operands[0], tmp, + const0_rtx)); + DONE; +}) diff --git 
a/gcc/config/mips/mips-protos.h b/gcc/config/mips/mips-protos.h index 90b4c87fdea1..96e084e6e641 100644 --- a/gcc/config/mips/mips-protos.h +++ b/gcc/config/mips/mips-protos.h @@ -352,6 +352,7 @@ extern void mips_expand_atomic_qihi (union mips_gen_fn_ptrs, extern void mips_expand_vector_init (rtx, rtx); extern void mips_expand_vec_unpack (rtx op[2], bool, bool); extern void mips_expand_vec_reduc (rtx, rtx, rtx (*)(rtx, rtx, rtx)); +extern void mips_expand_msa_reduc (rtx (*)(rtx, rtx, rtx), rtx, rtx); extern void mips_expand_vec_minmax (rtx, rtx, rtx, rtx (*) (rtx, rtx, rtx), bool); diff --git a/gcc/config/mips/mips.cc b/gcc/config/mips/mips.cc index 6c797b621643..173f792bf55a 100644 --- a/gcc/config/mips/mips.cc +++ b/gcc/config/mips/mips.cc @@ -22239,6 +22239,47 @@ mips_vectorize_vec_perm_const (machine_mode vmode, machine_mode op_mode, return ok; } +/* Expand a vector reduction. FN is the binary pattern to reduce; + DEST is the destination; IN is the input vector. */ + +void +mips_expand_msa_reduc (rtx (*fn) (rtx, rtx, rtx), rtx dest, rtx in) +{ + rtx swap, vec = in; + machine_mode mode = GET_MODE (in); + unsigned int i, gelt; + const unsigned nelt = GET_MODE_BITSIZE (mode) / GET_MODE_UNIT_BITSIZE (mode); + unsigned char perm[MAX_VECT_LEN]; + + /* We have no SHF.d. */ + if (nelt == 2) + { + perm[0] = 2; + perm[1] = 3; + perm[2] = 0; + perm[3] = 1; + rtx rsi = simplify_gen_subreg (V4SImode, in, mode, 0); + swap = gen_reg_rtx (V4SImode); + mips_expand_vselect (swap, rsi, perm, 4); + emit_move_insn (dest, gen_rtx_SUBREG (mode, swap, 0)); + emit_insn (fn (dest, dest, vec)); + return; + } + + for (gelt=1; gelt<=nelt/2; gelt *= 2) + { + for (i = 0; i + +#define D_TY_CALC(type) \ + type a_##type[32] __attribute__ ((aligned (16))); \ + type min_##type () { \ + type ret = a_##type[0]; \ + for (int i=0; i<32; i++) \ + ret = (ret < a_##type[i]) ? 
ret : a_##type[i]; \ + return ret; \ + } \ + type max_##type () { \ + type ret = a_##type[0]; \ + for (int i=0; i<32; i++) \ + ret = (ret > a_##type[i]) ? ret : a_##type[i]; \ + return ret; \ + } \ + type plus_##type () { \ + type ret = 0; \ + for (int i=0; i<32; i++) \ + ret += a_##type[i]; \ + return ret; \ + } + +#define D_TY_BIT(type) \ + type or_##type () { \ + type ret = 0; \ + for (int i=0; i<32; i++) \ + ret |= a_##type[i]; \ + return ret; \ + } \ + type and_##type () { \ + type ret = (type)(long long)~0LL; \ + for (int i=0; i<32; i++) \ + ret &= a_##type[i]; \ + return ret; \ + } \ + type xor_##type () { \ + type ret = (type)(long long)~0LL; \ + for (int i=0; i<32; i++) \ + ret ^= a_##type[i]; \ + return ret; \ + } + +#define D_TY(type) D_TY_CALC(type) D_TY_BIT(type) + +D_TY (int8_t) +D_TY (uint8_t) +D_TY (int16_t) +D_TY (uint16_t) +D_TY (int32_t) +D_TY (uint32_t) +D_TY (int64_t) +D_TY (uint64_t) +D_TY_CALC (float) +D_TY_CALC (double) + + -- 2.47.2