]> git.ipfire.org Git - thirdparty/gcc.git/commitdiff
aarch64: Add support for fp8fma instructions
authorSaurabh Jha <saurabh.jha@arm.com>
Wed, 13 Nov 2024 17:16:37 +0000 (17:16 +0000)
committerSaurabh Jha <saurabh.jha@arm.com>
Thu, 14 Nov 2024 09:59:41 +0000 (09:59 +0000)
The AArch64 FEAT_FP8FMA extension introduces instructions for
multiply-add of vectors.

This patch introduces the following instructions:
1. {vmlalbq|vmlaltq}_f16_mf8_fpm.
2. {vmlalbq|vmlaltq}_lane{q}_f16_mf8_fpm.
3. {vmlallbbq|vmlallbtq|vmlalltbq|vmlallttq}_f32_mf8_fpm.
4. {vmlallbbq|vmlallbtq|vmlalltbq|vmlallttq}_lane{q}_f32_mf8_fpm.

It introduces the fp8fma flag.

gcc/ChangeLog:

* config/aarch64/aarch64-builtins.cc
(check_simd_lane_bounds): Add support for new unspecs.
(aarch64_expand_pragma_builtins): Add support for new unspecs.
* config/aarch64/aarch64-c.cc
(aarch64_update_cpp_builtins): New flags.
* config/aarch64/aarch64-option-extensions.def
(AARCH64_OPT_EXTENSION): New flags.
* config/aarch64/aarch64-simd-pragma-builtins.def
(ENTRY_FMA_FPM): Macro to declare fma intrinsics.
(REQUIRED_EXTENSIONS): Define to declare functions behind
command line flags.
* config/aarch64/aarch64-simd.md:
(@aarch64_<fpm_uns_op><VQ_HSF:mode><VQ_HSF:mode><V16QI_ONLY:mode><V16QI_ONLY:mode): Instruction pattern for fma intrinsics.
(@aarch64_<fpm_uns_op><VQ_HSF:mode><VQ_HSF:mode><V16QI_ONLY:mode><VB:mode><SI_ONLY:mode): Instruction pattern for fma intrinsics with lane.
* config/aarch64/aarch64.h
(TARGET_FP8FMA): New flag for fp8fma instructions.
* config/aarch64/iterators.md: New attributes and iterators.
* doc/invoke.texi: New flag for fp8fma instructions.

gcc/testsuite/ChangeLog:

* gcc.target/aarch64/simd/fma_fpm.c: New test.

gcc/config/aarch64/aarch64-builtins.cc
gcc/config/aarch64/aarch64-c.cc
gcc/config/aarch64/aarch64-option-extensions.def
gcc/config/aarch64/aarch64-simd-pragma-builtins.def
gcc/config/aarch64/aarch64-simd.md
gcc/config/aarch64/aarch64.h
gcc/config/aarch64/iterators.md
gcc/doc/invoke.texi
gcc/testsuite/gcc.target/aarch64/simd/fma_fpm.c [new file with mode: 0644]

index a71c8c9a64e9847dabfb9ac650bdce5eabc820a7..7b2decf671fa9bd096c42227bbdd81dba1e4749c 100644 (file)
@@ -2562,10 +2562,26 @@ check_simd_lane_bounds (location_t location, const aarch64_pragma_builtins_data
          = GET_MODE_NUNITS (vector_to_index_mode).to_constant ();
 
        auto low = 0;
-       int high
-         = builtin_data->unspec == UNSPEC_VDOT2
-         ? vector_to_index_mode_size / 2 - 1
-         : vector_to_index_mode_size / 4 - 1;
+       int high;
+       switch (builtin_data->unspec)
+         {
+         case UNSPEC_VDOT2:
+           high = vector_to_index_mode_size / 2 - 1;
+           break;
+         case UNSPEC_VDOT4:
+           high = vector_to_index_mode_size / 4 - 1;
+           break;
+         case UNSPEC_FMLALB:
+         case UNSPEC_FMLALT:
+         case UNSPEC_FMLALLBB:
+         case UNSPEC_FMLALLBT:
+         case UNSPEC_FMLALLTB:
+         case UNSPEC_FMLALLTT:
+           high = vector_to_index_mode_size - 1;
+           break;
+         default:
+           gcc_unreachable ();
+         }
        require_immediate_range (location, index_arg, low, high);
        break;
       }
@@ -3552,6 +3568,12 @@ aarch64_expand_pragma_builtin (tree exp, rtx target,
 
     case UNSPEC_VDOT2:
     case UNSPEC_VDOT4:
+    case UNSPEC_FMLALB:
+    case UNSPEC_FMLALT:
+    case UNSPEC_FMLALLBB:
+    case UNSPEC_FMLALLBT:
+    case UNSPEC_FMLALLTB:
+    case UNSPEC_FMLALLTT:
       if (builtin_data->signature == aarch64_builtin_signatures::ternary)
        icode = code_for_aarch64 (builtin_data->unspec,
                                  builtin_data->types[0].mode,
index ae1472e0fcf27f8dcfebb30e39d678604a84d615..03f912cde077ff82af9a61d07d1ee29fc56b0d9a 100644 (file)
@@ -264,6 +264,8 @@ aarch64_update_cpp_builtins (cpp_reader *pfile)
 
   aarch64_def_or_undef (TARGET_FP8DOT4, "__ARM_FEATURE_FP8DOT4", pfile);
 
+  aarch64_def_or_undef (TARGET_FP8FMA, "__ARM_FEATURE_FP8FMA", pfile);
+
   aarch64_def_or_undef (TARGET_LS64,
                        "__ARM_FEATURE_LS64", pfile);
   aarch64_def_or_undef (TARGET_RCPC, "__ARM_FEATURE_RCPC", pfile);
index 44d2e18d46bd9c1773424e05fbf45bbeb574eabb..8446d1bcd5dca409b8d8c0a5d7015bda4170e82a 100644 (file)
@@ -240,6 +240,8 @@ AARCH64_OPT_EXTENSION("fp8dot2", FP8DOT2, (SIMD), (), (), "fp8dot2")
 
 AARCH64_OPT_EXTENSION("fp8dot4", FP8DOT4, (SIMD), (), (), "fp8dot4")
 
+AARCH64_OPT_EXTENSION("fp8fma", FP8FMA, (SIMD), (), (), "fp8fma")
+
 AARCH64_OPT_EXTENSION("faminmax", FAMINMAX, (SIMD), (), (), "faminmax")
 
 #undef AARCH64_OPT_FMV_EXTENSION
index 4a94a6613f08d064e4af183e9cc6a10999b9f61e..c7857123ca03cd297c5800fd709d4374fc9d152a 100644 (file)
   ENTRY_TERNARY_FPM_LANE (vdotq_lane_##T##_mf8_fpm, T##q, T##q, f8q, f8, U) \
   ENTRY_TERNARY_FPM_LANE (vdotq_laneq_##T##_mf8_fpm, T##q, T##q, f8q, f8q, U)
 
+#undef ENTRY_FMA_FPM
+#define ENTRY_FMA_FPM(N, T, U)                                         \
+  ENTRY_TERNARY_FPM (N##_##T##_mf8_fpm, T##q, T##q, f8q, f8q, U)       \
+  ENTRY_TERNARY_FPM_LANE (N##_lane_##T##_mf8_fpm, T##q, T##q, f8q, f8, U) \
+  ENTRY_TERNARY_FPM_LANE (N##_laneq_##T##_mf8_fpm, T##q, T##q, f8q, f8q, U)
+
 #undef ENTRY_VHSDF
 #define ENTRY_VHSDF(NAME, UNSPEC) \
   ENTRY_BINARY (NAME##_f16, f16, f16, f16, UNSPEC)             \
@@ -106,3 +112,13 @@ ENTRY_VDOT_FPM (f16, UNSPEC_VDOT2)
 #define REQUIRED_EXTENSIONS nonstreaming_only (AARCH64_FL_FP8DOT4)
 ENTRY_VDOT_FPM (f32, UNSPEC_VDOT4)
 #undef REQUIRED_EXTENSIONS
+
+// fp8 multiply-add
+#define REQUIRED_EXTENSIONS nonstreaming_only (AARCH64_FL_FP8FMA)
+ENTRY_FMA_FPM (vmlalbq, f16, UNSPEC_FMLALB)
+ENTRY_FMA_FPM (vmlaltq, f16, UNSPEC_FMLALT)
+ENTRY_FMA_FPM (vmlallbbq, f32, UNSPEC_FMLALLBB)
+ENTRY_FMA_FPM (vmlallbtq, f32, UNSPEC_FMLALLBT)
+ENTRY_FMA_FPM (vmlalltbq, f32, UNSPEC_FMLALLTB)
+ENTRY_FMA_FPM (vmlallttq, f32, UNSPEC_FMLALLTT)
+#undef REQUIRED_EXTENSIONS
index 7b974865f5559098c73402eb4bb1032e3dbaf7bd..df0d30af6a118dabc29bf3c0c0aa771ae90b19d2 100644 (file)
   "TARGET_FP8DOT4"
   "<fpm_uns_op>\t%1.<VDQSF:Vtype>, %2.<VB:Vtype>, %3.<VDQSF:Vdotlanetype>[%4]"
 )
+
+;; fpm fma instructions.
+(define_insn
+  "@aarch64_<fpm_uns_op><VQ_HSF:mode><VQ_HSF:mode><V16QI_ONLY:mode><V16QI_ONLY:mode>"
+  [(set (match_operand:VQ_HSF 0 "register_operand" "=w")
+       (unspec:VQ_HSF
+        [(match_operand:VQ_HSF 1 "register_operand" "w")
+         (match_operand:V16QI_ONLY 2 "register_operand" "w")
+         (match_operand:V16QI_ONLY 3 "register_operand" "w")
+         (reg:DI FPM_REGNUM)]
+       FPM_FMA_UNS))]
+  "TARGET_FP8FMA"
+  "<fpm_uns_op>\t%1.<VQ_HSF:Vtype>, %2.<V16QI_ONLY:Vtype>, %3.<V16QI_ONLY:Vtype>"
+)
+
+;; fpm fma instructions with lane.
+(define_insn
+  "@aarch64_<fpm_uns_op><VQ_HSF:mode><VQ_HSF:mode><V16QI_ONLY:mode><VB:mode><SI_ONLY:mode>"
+  [(set (match_operand:VQ_HSF 0 "register_operand" "=w")
+       (unspec:VQ_HSF
+        [(match_operand:VQ_HSF 1 "register_operand" "w")
+         (match_operand:V16QI_ONLY 2 "register_operand" "w")
+         (match_operand:VB 3 "register_operand" "w")
+         (match_operand:SI_ONLY 4 "const_int_operand" "n")
+         (reg:DI FPM_REGNUM)]
+       FPM_FMA_UNS))]
+  "TARGET_FP8FMA"
+  "<fpm_uns_op>\t%1.<VQ_HSF:Vtype>, %2.<V16QI_ONLY:Vtype>, %3.b[%4]"
+)
index c50a578731a5bd9cd09a3bf6dd8d9fe07200b2bd..a691a0f2b181d1527c7ed4a0a2a80e0772037aac 100644 (file)
@@ -500,6 +500,9 @@ constexpr auto AARCH64_FL_DEFAULT_ISA_MODE ATTRIBUTE_UNUSED
 /* fp8 dot product instructions are enabled through +fp8dot4.  */
 #define TARGET_FP8DOT4 AARCH64_HAVE_ISA (FP8DOT4)
 
+/* fp8 multiply-add instructions are enabled through +fp8fma.  */
+#define TARGET_FP8FMA AARCH64_HAVE_ISA (FP8FMA)
+
 /* Standard register usage.  */
 
 /* 31 64-bit general purpose registers R0-R30:
index 8c03dcd14dd15f4150852f67433bed73b244da1c..82dc7dcf7621e9fd74b0d4f2534ae427bd2e9064 100644 (file)
     UNSPEC_FMINNMV     ; Used in aarch64-simd.md.
     UNSPEC_FMINV       ; Used in aarch64-simd.md.
     UNSPEC_FADDV       ; Used in aarch64-simd.md.
+    UNSPEC_FMLALLBB    ; Used in aarch64-simd.md.
+    UNSPEC_FMLALLBT    ; Used in aarch64-simd.md.
+    UNSPEC_FMLALLTB    ; Used in aarch64-simd.md.
+    UNSPEC_FMLALLTT    ; Used in aarch64-simd.md.
     UNSPEC_FNEG                ; Used in aarch64-simd.md.
     UNSPEC_FSCALE      ; Used in aarch64-simd.md.
     UNSPEC_ADDV                ; Used in aarch64-simd.md.
 (define_int_iterator FPM_VDOT2_UNS [UNSPEC_VDOT2])
 (define_int_iterator FPM_VDOT4_UNS [UNSPEC_VDOT4])
 
+(define_int_iterator FPM_FMA_UNS
+  [UNSPEC_FMLALB
+   UNSPEC_FMLALT
+   UNSPEC_FMLALLBB
+   UNSPEC_FMLALLBT
+   UNSPEC_FMLALLTB
+   UNSPEC_FMLALLTT])
+
 (define_int_attr fpm_uns_op
   [(UNSPEC_FSCALE "fscale")
    (UNSPEC_VCVT "fcvtn")
    (UNSPEC_VCVT_HIGH "fcvtn2")
+   (UNSPEC_FMLALB "fmlalb")
+   (UNSPEC_FMLALT "fmlalt")
+   (UNSPEC_FMLALLBB "fmlallbb")
+   (UNSPEC_FMLALLBT "fmlallbt")
+   (UNSPEC_FMLALLTB "fmlalltb")
+   (UNSPEC_FMLALLTT "fmlalltt")
    (UNSPEC_VDOT2 "fdot")
    (UNSPEC_VDOT4 "fdot")])
index bc3f74234259354290f5a10bb3691759823f4815..d41136bebc1ca5bb25441d36d3881d78962a7505 100644 (file)
@@ -21811,6 +21811,8 @@ Enable the fp8 (8-bit floating point) extension.
 Enable the fp8dot2 (8-bit floating point dot product) extension.
 @item fp8dot4
 Enable the fp8dot4 (8-bit floating point dot product) extension.
+@item fp8fma
+Enable the fp8fma (8-bit floating point multiply-add) extension.
 @item faminmax
 Enable the Floating Point Absolute Maximum/Minimum extension.
 
diff --git a/gcc/testsuite/gcc.target/aarch64/simd/fma_fpm.c b/gcc/testsuite/gcc.target/aarch64/simd/fma_fpm.c
new file mode 100644 (file)
index 0000000..ea21856
--- /dev/null
@@ -0,0 +1,221 @@
+/* { dg-do compile } */
+/* { dg-additional-options "-O3 -march=armv9-a+fp8fma" } */
+/* { dg-final { check-function-bodies "**" "" } } */
+
+#include "arm_neon.h"
+
+/*
+** test_vmlalbq_f16_fpm:
+**     msr     fpmr, x0
+**     fmlalb  v0.8h, v1.16b, v2.16b
+**     ret
+*/
+float16x8_t
+test_vmlalbq_f16_fpm (float16x8_t a, mfloat8x16_t b, mfloat8x16_t c, fpm_t d)
+{
+  return vmlalbq_f16_mf8_fpm (a, b, c, d);
+}
+
+/*
+** test_vmlaltq_f16_fpm:
+**     msr     fpmr, x0
+**     fmlalt  v0.8h, v1.16b, v2.16b
+**     ret
+*/
+float16x8_t
+test_vmlaltq_f16_fpm (float16x8_t a, mfloat8x16_t b, mfloat8x16_t c, fpm_t d)
+{
+  return vmlaltq_f16_mf8_fpm (a, b, c, d);
+}
+
+/*
+** test_vmlallbbq_f32_fpm:
+**     msr     fpmr, x0
+**     fmlallbb        v0.4s, v1.16b, v2.16b
+**     ret
+*/
+float32x4_t
+test_vmlallbbq_f32_fpm (float32x4_t a, mfloat8x16_t b, mfloat8x16_t c, fpm_t d)
+{
+  return vmlallbbq_f32_mf8_fpm (a, b, c, d);
+}
+
+/*
+** test_vmlallbtq_f32_fpm:
+**     msr     fpmr, x0
+**     fmlallbt        v0.4s, v1.16b, v2.16b
+**     ret
+*/
+float32x4_t
+test_vmlallbtq_f32_fpm (float32x4_t a, mfloat8x16_t b, mfloat8x16_t c, fpm_t d)
+{
+  return vmlallbtq_f32_mf8_fpm (a, b, c, d);
+}
+
+/*
+** test_vmlalltbq_f32_fpm:
+**     msr     fpmr, x0
+**     fmlalltb        v0.4s, v1.16b, v2.16b
+**     ret
+*/
+float32x4_t
+test_vmlalltbq_f32_fpm (float32x4_t a, mfloat8x16_t b, mfloat8x16_t c, fpm_t d)
+{
+  return vmlalltbq_f32_mf8_fpm (a, b, c, d);
+}
+
+/*
+** test_vmlallttq_f32_fpm:
+**     msr     fpmr, x0
+**     fmlalltt        v0.4s, v1.16b, v2.16b
+**     ret
+*/
+float32x4_t
+test_vmlallttq_f32_fpm (float32x4_t a, mfloat8x16_t b, mfloat8x16_t c, fpm_t d)
+{
+  return vmlallttq_f32_mf8_fpm (a, b, c, d);
+}
+
+/*
+** test_vmlalbq_lane_f16_fpm:
+**     msr     fpmr, x0
+**     fmlalb  v0.8h, v1.16b, v2.b\[1\]
+**     ret
+*/
+float16x8_t
+test_vmlalbq_lane_f16_fpm (float16x8_t a, mfloat8x16_t b, mfloat8x8_t c, fpm_t d)
+{
+  return vmlalbq_lane_f16_mf8_fpm (a, b, c, 1, d);
+}
+
+/*
+** test_vmlalbq_laneq_f16_fpm:
+**     msr     fpmr, x0
+**     fmlalb  v0.8h, v1.16b, v2.b\[1\]
+**     ret
+*/
+float16x8_t
+test_vmlalbq_laneq_f16_fpm (float16x8_t a, mfloat8x16_t b, mfloat8x16_t c, fpm_t d)
+{
+  return vmlalbq_laneq_f16_mf8_fpm (a, b, c, 1, d);
+}
+
+/*
+** test_vmlaltq_lane_f16_fpm:
+**     msr     fpmr, x0
+**     fmlalt  v0.8h, v1.16b, v2.b\[1\]
+**     ret
+*/
+float16x8_t
+test_vmlaltq_lane_f16_fpm (float16x8_t a, mfloat8x16_t b, mfloat8x8_t c, fpm_t d)
+{
+  return vmlaltq_lane_f16_mf8_fpm (a, b, c, 1, d);
+}
+
+/*
+** test_vmlaltq_laneq_f16_fpm:
+**     msr     fpmr, x0
+**     fmlalt  v0.8h, v1.16b, v2.b\[1\]
+**     ret
+*/
+float16x8_t
+test_vmlaltq_laneq_f16_fpm (float16x8_t a, mfloat8x16_t b, mfloat8x16_t c, fpm_t d)
+{
+  return vmlaltq_laneq_f16_mf8_fpm (a, b, c, 1, d);
+}
+
+/*
+** test_vmlallbbq_lane_f32_fpm:
+**     msr     fpmr, x0
+**     fmlallbb        v0.4s, v1.16b, v2.b\[1\]
+**     ret
+*/
+float32x4_t
+test_vmlallbbq_lane_f32_fpm (float32x4_t a, mfloat8x16_t b, mfloat8x8_t c, fpm_t d)
+{
+  return vmlallbbq_lane_f32_mf8_fpm (a, b, c, 1, d);
+}
+
+/*
+** test_vmlallbbq_laneq_f32_fpm:
+**     msr     fpmr, x0
+**     fmlallbb        v0.4s, v1.16b, v2.b\[1\]
+**     ret
+*/
+float32x4_t
+test_vmlallbbq_laneq_f32_fpm (float32x4_t a, mfloat8x16_t b, mfloat8x16_t c, fpm_t d)
+{
+  return vmlallbbq_laneq_f32_mf8_fpm (a, b, c, 1, d);
+}
+
+/*
+** test_vmlallbtq_lane_f32_fpm:
+**     msr     fpmr, x0
+**     fmlallbt        v0.4s, v1.16b, v2.b\[1\]
+**     ret
+*/
+float32x4_t
+test_vmlallbtq_lane_f32_fpm (float32x4_t a, mfloat8x16_t b, mfloat8x8_t c, fpm_t d)
+{
+  return vmlallbtq_lane_f32_mf8_fpm (a, b, c, 1, d);
+}
+
+/*
+** test_vmlallbtq_laneq_f32_fpm:
+**     msr     fpmr, x0
+**     fmlallbt        v0.4s, v1.16b, v2.b\[1\]
+**     ret
+*/
+float32x4_t
+test_vmlallbtq_laneq_f32_fpm (float32x4_t a, mfloat8x16_t b, mfloat8x16_t c, fpm_t d)
+{
+  return vmlallbtq_laneq_f32_mf8_fpm (a, b, c, 1, d);
+}
+
+/*
+** test_vmlalltbq_lane_f32_fpm:
+**     msr     fpmr, x0
+**     fmlalltb        v0.4s, v1.16b, v2.b\[1\]
+**     ret
+*/
+float32x4_t
+test_vmlalltbq_lane_f32_fpm (float32x4_t a, mfloat8x16_t b, mfloat8x8_t c, fpm_t d)
+{
+  return vmlalltbq_lane_f32_mf8_fpm (a, b, c, 1, d);
+}
+
+/*
+** test_vmlalltbq_laneq_f32_fpm:
+**     msr     fpmr, x0
+**     fmlalltb        v0.4s, v1.16b, v2.b\[1\]
+**     ret
+*/
+float32x4_t
+test_vmlalltbq_laneq_f32_fpm (float32x4_t a, mfloat8x16_t b, mfloat8x16_t c, fpm_t d)
+{
+  return vmlalltbq_laneq_f32_mf8_fpm (a, b, c, 1, d);
+}
+
+/*
+** test_vmlallttq_lane_f32_fpm:
+**     msr     fpmr, x0
+**     fmlalltt        v0.4s, v1.16b, v2.b\[1\]
+**     ret
+*/
+float32x4_t
+test_vmlallttq_lane_f32_fpm (float32x4_t a, mfloat8x16_t b, mfloat8x8_t c, fpm_t d)
+{
+  return vmlallttq_lane_f32_mf8_fpm (a, b, c, 1, d);
+}
+
+/*
+** test_vmlallttq_laneq_f32_fpm:
+**     msr     fpmr, x0
+**     fmlalltt        v0.4s, v1.16b, v2.b\[1\]
+**     ret
+*/
+float32x4_t
+test_vmlallttq_laneq_f32_fpm (float32x4_t a, mfloat8x16_t b, mfloat8x16_t c, fpm_t d)
+{
+  return vmlallttq_laneq_f32_mf8_fpm (a, b, c, 1, d);
+}