]> git.ipfire.org Git - thirdparty/gcc.git/commitdiff
aarch64: Fix aarch64 backend-use of (u|s|us)dot_prod patterns
authorVictor Do Nascimento <victor.donascimento@arm.com>
Tue, 21 May 2024 16:13:03 +0000 (17:13 +0100)
committerVictor Do Nascimento <victor.donascimento@arm.com>
Mon, 30 Sep 2024 14:59:43 +0000 (15:59 +0100)
Given recent changes to the dot_prod standard pattern name, this patch
fixes the aarch64 back-end by implementing the following changes:

1. Add 2nd mode to all (u|s|us)dot_prod patterns in .md files.
2. Rewrite initialization and function expansion mechanism for simd
builtins.
3. Fix all direct calls to back-end `dot_prod' patterns in SVE
builtins.

Finally, given that it is now possible for the compiler to
differentiate between the two- and four-way dot product, we add a test
to ensure that autovectorization picks up on dot-product patterns
where the result is twice the width of the operands.

gcc/ChangeLog:

* config/aarch64/aarch64-simd.md
(<sur>dot_prod<vsi2qi><vczle><vczbe>): Renamed to...
(<sur>dot_prod<mode><vsi2qi><vczle><vczbe>): ...this.
(usdot_prod<vsi2qi><vczle><vczbe>): Renamed to...
(usdot_prod<mode><vsi2qi><vczle><vczbe>): ...this.
(<su>sadv16qi): Adjust call to gen_udot_prod take second mode.
(popcount<mode2>): fix use of `udot_prod_optab'.
* config/aarch64/aarch64-sve.md
(<sur>dot_prod<vsi2qi>): Renamed to...
(<sur>dot_prod<mode><vsi2qi>): ...this.
(@<sur>dot_prod<vsi2qi>): Renamed to...
(@<sur>dot_prod<mode><vsi2qi>): ...this.
(<su>sad<vsi2qi>): Adjust call to gen_udot_prod take second mode.
* config/aarch64/aarch64-sve2.md
(@aarch64_sve_<sur>dotvnx4sivnx8hi): Renamed to...
(<sur>dot_prodvnx4sivnx8hi): ...this.
* config/aarch64/aarch64-simd-builtins.def: Modify macro
expansion-based initialization and expansion
of (u|s|us)dot_prod builtins.
* config/aarch64/aarch64-builtins.cc
(CODE_FOR_aarch64_sdot_prodv8qi): Define as alias to
new CODE_FOR_sdot_prodv2siv8qi.
(CODE_FOR_aarch64_udot_prodv8qi): Define as alias to
new CODE_FOR_udot_prodv2siv8qi.
(CODE_FOR_aarch64_usdot_prodv8qi): Define as alias to
new CODE_FOR_usdot_prodv2siv8qi.
(CODE_FOR_aarch64_sdot_prodv16qi): Define as alias to
new CODE_FOR_sdot_prodv4siv16qi.
(CODE_FOR_aarch64_udot_prodv16qi): Define as alias to
new CODE_FOR_udot_prodv4siv16qi.
(CODE_FOR_aarch64_usdot_prodv16qi): Define as alias to
new CODE_FOR_usdot_prodv4siv16qi.
* config/aarch64/aarch64-sve-builtins-base.cc
(svdot_impl::expand): s/direct/convert/ in
`convert_optab_handler_for_sign' function call.
(svusdot_impl::expand): add second mode argument in call to
`code_for_dot_prod'.
* config/aarch64/aarch64-sve-builtins.cc
(function_expander::convert_optab_handler_for_sign): New class
method.
* config/aarch64/aarch64-sve-builtins.h
(class function_expander): Add prototype for new
`convert_optab_handler_for_sign' method.

gcc/testsuite/ChangeLog:
* gcc.target/aarch64/sme/vect-dotprod-twoway.c (udot2): New.

gcc/config/aarch64/aarch64-builtins.cc
gcc/config/aarch64/aarch64-simd-builtins.def
gcc/config/aarch64/aarch64-simd.md
gcc/config/aarch64/aarch64-sve-builtins-base.cc
gcc/config/aarch64/aarch64-sve-builtins.cc
gcc/config/aarch64/aarch64-sve-builtins.h
gcc/config/aarch64/aarch64-sve.md
gcc/config/aarch64/aarch64-sve2.md
gcc/testsuite/gcc.target/aarch64/sme/vect-dotprod-twoway.c [new file with mode: 0644]

index 6266bea3b39c57d0077defec2aac08cd11160551..38b860c176a4daecdec698e218779432c72a8cf9 100644 (file)
@@ -458,6 +458,19 @@ aarch64_types_storestruct_lane_p_qualifiers[SIMD_MAX_BUILTIN_ARGS]
       qualifier_poly, qualifier_struct_load_store_lane_index };
 #define TYPES_STORESTRUCT_LANE_P (aarch64_types_storestruct_lane_p_qualifiers)
 
+constexpr insn_code CODE_FOR_aarch64_sdot_prodv8qi
+  = CODE_FOR_sdot_prodv2siv8qi;
+constexpr insn_code CODE_FOR_aarch64_udot_prodv8qi
+  = CODE_FOR_udot_prodv2siv8qi;
+constexpr insn_code CODE_FOR_aarch64_usdot_prodv8qi
+  = CODE_FOR_usdot_prodv2siv8qi;
+constexpr insn_code CODE_FOR_aarch64_sdot_prodv16qi
+  = CODE_FOR_sdot_prodv4siv16qi;
+constexpr insn_code CODE_FOR_aarch64_udot_prodv16qi
+  = CODE_FOR_udot_prodv4siv16qi;
+constexpr insn_code CODE_FOR_aarch64_usdot_prodv16qi
+  = CODE_FOR_usdot_prodv4siv16qi;
+
 #define CF0(N, X) CODE_FOR_aarch64_##N##X
 #define CF1(N, X) CODE_FOR_##N##X##1
 #define CF2(N, X) CODE_FOR_##N##X##2
index e65f73d7ba23fe336ef64624c8eebb00415c5b5c..0814f8ba14f52ceaa7a00fcd2e45550238412253 100644 (file)
   BUILTIN_VSDQ_I_DI (BINOP_UUS, urshl, 0, NONE)
 
   /* Implemented by <sur><dotprod>_prod<dot_mode>.  */
-  BUILTIN_VB (TERNOP, sdot_prod, 10, NONE)
-  BUILTIN_VB (TERNOPU, udot_prod, 10, NONE)
-  BUILTIN_VB (TERNOP_SUSS, usdot_prod, 10, NONE)
+  BUILTIN_VB (TERNOP, sdot_prod, 0, NONE)
+  BUILTIN_VB (TERNOPU, udot_prod, 0, NONE)
+  BUILTIN_VB (TERNOP_SUSS, usdot_prod, 0, NONE)
   /* Implemented by aarch64_<sur><dotprod>_lane{q}<dot_mode>.  */
   BUILTIN_VB (QUADOP_LANE, sdot_lane, 0, NONE)
   BUILTIN_VB (QUADOPU_LANE, udot_lane, 0, NONE)
index 2a44aa3fcc3321f410c9c68d862a6fcd4666fe49..11d405ed640f7937f985c4bae43ecd634a096604 100644 (file)
 ;; ...
 ;;
 ;; and so the vectorizer provides r, in which the result has to be accumulated.
-(define_insn "<sur>dot_prod<vsi2qi><vczle><vczbe>"
+(define_insn "<sur>dot_prod<mode><vsi2qi><vczle><vczbe>"
   [(set (match_operand:VS 0 "register_operand" "=w")
        (plus:VS
          (unspec:VS [(match_operand:<VSI2QI> 1 "register_operand" "w")
 
 ;; These instructions map to the __builtins for the Armv8.6-a I8MM usdot
 ;; (vector) Dot Product operation and the vectorized optab.
-(define_insn "usdot_prod<vsi2qi><vczle><vczbe>"
+(define_insn "usdot_prod<mode><vsi2qi><vczle><vczbe>"
   [(set (match_operand:VS 0 "register_operand" "=w")
        (plus:VS
          (unspec:VS [(match_operand:<VSI2QI> 1 "register_operand" "w")
        rtx ones = force_reg (V16QImode, CONST1_RTX (V16QImode));
        rtx abd = gen_reg_rtx (V16QImode);
        emit_insn (gen_aarch64_<su>abdv16qi (abd, operands[1], operands[2]));
-       emit_insn (gen_udot_prodv16qi (operands[0], abd, ones, operands[3]));
+       emit_insn (gen_udot_prodv4siv16qi (operands[0], abd, ones,
+                                          operands[3]));
        DONE;
       }
     rtx reduc = gen_reg_rtx (V8HImode);
 
     /* Generate a byte popcount.  */
     machine_mode mode = <bitsize> == 64 ? V8QImode : V16QImode;
+    machine_mode mode2 = <bitsize> == 64 ? V2SImode : V4SImode;
     rtx tmp = gen_reg_rtx (mode);
     auto icode = optab_handler (popcount_optab, mode);
     emit_insn (GEN_FCN (icode) (tmp, gen_lowpart (mode, operands[1])));
        /* For V4SI and V2SI, we can generate a UDOT with a 0 accumulator and a
           1 multiplicand.  For V2DI, another UAADDLP is needed.  */
        rtx ones = force_reg (mode, CONST1_RTX (mode));
-       auto icode = optab_handler (udot_prod_optab, mode);
+       auto icode = convert_optab_handler (udot_prod_optab, mode2, mode);
        mode = <bitsize> == 64 ? V2SImode : V4SImode;
        rtx dest = mode == <MODE>mode ? operands[0] : gen_reg_rtx (mode);
        rtx zeros = force_reg (mode, CONST0_RTX (mode));
index afce52a7e8dd009c2d748af06490d0ed6ddd2bcd..4b33585d98145770195ede58ad5c16d5210cd414 100644 (file)
@@ -820,15 +820,16 @@ public:
     e.rotate_inputs_left (0, 3);
     insn_code icode;
     if (e.type_suffix_ids[1] == NUM_TYPE_SUFFIXES)
-      icode = e.direct_optab_handler_for_sign (sdot_prod_optab,
-                                              udot_prod_optab,
-                                              0, GET_MODE (e.args[0]));
+      icode = e.convert_optab_handler_for_sign (sdot_prod_optab,
+                                               udot_prod_optab,
+                                               0, e.result_mode (),
+                                               GET_MODE (e.args[0]));
     else
       icode = (e.type_suffix (0).float_p
               ? CODE_FOR_aarch64_sve_fdotvnx4sfvnx8hf
               : e.type_suffix (0).unsigned_p
-              ? CODE_FOR_aarch64_sve_udotvnx4sivnx8hi
-              : CODE_FOR_aarch64_sve_sdotvnx4sivnx8hi);
+              ? CODE_FOR_udot_prodvnx4sivnx8hi
+              : CODE_FOR_sdot_prodvnx4sivnx8hi);
     return e.use_unpred_insn (icode);
   }
 };
@@ -2905,7 +2906,7 @@ public:
        Hence we do the same rotation on arguments as svdot_impl does.  */
     e.rotate_inputs_left (0, 3);
     machine_mode mode = e.vector_mode (0);
-    insn_code icode = code_for_dot_prod (UNSPEC_USDOT, mode);
+    insn_code icode = code_for_dot_prod (UNSPEC_USDOT, e.result_mode (), mode);
     return e.use_exact_insn (icode);
   }
 
index 8f9aa3cf1207918199be785299590a11ff59e18e..5ff46212d18d0c8dc99e289c89ad8a9593a6662a 100644 (file)
@@ -3690,6 +3690,21 @@ function_expander::direct_optab_handler_for_sign (optab signed_op,
   return ::direct_optab_handler (op, mode);
 }
 
+/* Choose between signed and unsigned convert optabs SIGNED_OP and
+   UNSIGNED_OP based on the signedness of type suffix SUFFIX_I, then
+   pick the appropriate optab handler for "converting" from FROM_MODE
+   to TO_MODE.  */
+insn_code
+function_expander::convert_optab_handler_for_sign (optab signed_op,
+                                                  optab unsigned_op,
+                                                  unsigned int suffix_i,
+                                                  machine_mode to_mode,
+                                                  machine_mode from_mode)
+{
+  optab op = type_suffix (suffix_i).unsigned_p ? unsigned_op : signed_op;
+  return ::convert_optab_handler (op, to_mode, from_mode);
+}
+
 /* Return true if X overlaps any input.  */
 bool
 function_expander::overlaps_input_p (rtx x)
index e3880503da021911e91daf32a4498422eef8a906..645e56badbeb22243e101adc0c0aea39851de741 100644 (file)
@@ -660,6 +660,8 @@ public:
   insn_code direct_optab_handler (optab, unsigned int = 0);
   insn_code direct_optab_handler_for_sign (optab, optab, unsigned int = 0,
                                           machine_mode = E_VOIDmode);
+  insn_code convert_optab_handler_for_sign (optab, optab, unsigned int,
+                                           machine_mode, machine_mode);
 
   machine_mode result_mode () const;
 
index bfa28849adf827704bcc14e737179dd8e57194cc..f6c7c2f4cb31a8ce310df7eb52f5fb5f80ebad0e 100644 (file)
 ;; -------------------------------------------------------------------------
 
 ;; Four-element integer dot-product with accumulation.
-(define_insn "<sur>dot_prod<vsi2qi>"
+(define_insn "<sur>dot_prod<mode><vsi2qi>"
   [(set (match_operand:SVE_FULL_SDI 0 "register_operand")
        (plus:SVE_FULL_SDI
          (unspec:SVE_FULL_SDI
   }
 )
 
-(define_insn "@<sur>dot_prod<vsi2qi>"
+(define_insn "@<sur>dot_prod<mode><vsi2qi>"
   [(set (match_operand:VNx4SI_ONLY 0 "register_operand")
         (plus:VNx4SI_ONLY
          (unspec:VNx4SI_ONLY
     rtx ones = force_reg (<VSI2QI>mode, CONST1_RTX (<VSI2QI>mode));
     rtx diff = gen_reg_rtx (<VSI2QI>mode);
     emit_insn (gen_<su>abd<vsi2qi>3 (diff, operands[1], operands[2]));
-    emit_insn (gen_udot_prod<vsi2qi> (operands[0], diff, ones, operands[3]));
+    emit_insn (gen_udot_prod<mode><vsi2qi> (operands[0], diff, ones,
+                                           operands[3]));
     DONE;
   }
 )
index 972b03a4fef0b0bd4d50edf392bcfcb9acde551e..725092cc95f0c22130258496b0d04d16204b98dc 100644 (file)
 )
 
 ;; Two-way dot-product.
-(define_insn "@aarch64_sve_<sur>dotvnx4sivnx8hi"
+(define_insn "<sur>dot_prodvnx4sivnx8hi"
   [(set (match_operand:VNx4SI 0 "register_operand")
        (plus:VNx4SI
          (unspec:VNx4SI
diff --git a/gcc/testsuite/gcc.target/aarch64/sme/vect-dotprod-twoway.c b/gcc/testsuite/gcc.target/aarch64/sme/vect-dotprod-twoway.c
new file mode 100644 (file)
index 0000000..77a019a
--- /dev/null
@@ -0,0 +1,26 @@
+/* { dg-additional-options "-O2 -ftree-vectorize" } */
+
+#include <stdint.h>
+#pragma GCC target "+sme2"
+
+uint32_t udot2(int n, uint16_t* data) __arm_streaming
+{
+  uint32_t sum = 0;
+  for (int i=0; i<n; i+=1) {
+    sum += data[i] * data[i];
+  }
+  return sum;
+}
+
+int32_t sdot2(int n, int16_t* data) __arm_streaming
+{
+  int32_t sum = 0;
+  for (int i=0; i<n; i+=1) {
+    sum += data[i] * data[i];
+  }
+  return sum;
+}
+
+/* { dg-final { scan-assembler-times {\tudot\tz[0-9]+\.s, z[0-9]+\.h, z[0-9]+\.h\n} 5 } } */
+/* { dg-final { scan-assembler-times {\tsdot\tz[0-9]+\.s, z[0-9]+\.h, z[0-9]+\.h\n} 5 } } */
+/* { dg-final { scan-assembler-times {\twhilelo\t} 4 } } */