]> git.ipfire.org Git - thirdparty/gcc.git/commitdiff
vect: Fix operand swapping on complex multiplication detection [PR122408]
authorTamar Christina <tamar.christina@arm.com>
Mon, 27 Oct 2025 17:55:38 +0000 (17:55 +0000)
committerTamar Christina <tamar.christina@arm.com>
Fri, 31 Oct 2025 16:22:38 +0000 (16:22 +0000)
For

SUBROUTINE a( j, b, c, d )
  !GCC$ ATTRIBUTES noinline :: a
  COMPLEX*16         b
  COMPLEX*16         c( * ), d( * )
  DO k = 1, j
     c( k ) = - b * CONJG( d( k ) )
  END DO
END

we incorrectly generate .IFN_COMPLEX_MUL instead of .IFN_COMPLEX_MUL_CONJ.

The issue happens because in the call to vect_validate_multiplication the
operand vectors are passed by reference and so the stripping of the NEGATE_EXPR
after matching modifies the input vector.  If validation fail we flip the
operands and try again.  But we've already stipped the negates and so if we
match we would match a normal multiply.

This fixes the API by marking the operands as const and instead pass an explicit
output vec that's to be used.  This also reduces the number of copies we were
doing.

With this we now correctly detect .IFN_COMPLEX_MUL_CONJ.  Weirdly enough I
couldn't reproduce this with any C example because they get reassociated
differently and always succeed on the first attempt.  Fortran is easy to
trigger though so new fortran tests added.

gcc/ChangeLog:

PR tree-optimization/122408
* tree-vect-slp-patterns.cc (vect_validate_multiplication): Cleanup and
document interface.
(complex_mul_pattern::matches, complex_fms_pattern::matches): Update to
new interface.

gcc/testsuite/ChangeLog:

PR tree-optimization/122408
* gfortran.target/aarch64/pr122408_1.f90: New test.
* gfortran.target/aarch64/pr122408_2.f90: New test.

(cherry picked from commit c5fa3d4c88fc4f8799318e463c47941eb52b7546)

gcc/testsuite/gfortran.target/aarch64/pr122408_1.f90 [new file with mode: 0644]
gcc/testsuite/gfortran.target/aarch64/pr122408_2.f90 [new file with mode: 0644]
gcc/tree-vect-slp-patterns.cc

diff --git a/gcc/testsuite/gfortran.target/aarch64/pr122408_1.f90 b/gcc/testsuite/gfortran.target/aarch64/pr122408_1.f90
new file mode 100644 (file)
index 0000000..8a34162
--- /dev/null
@@ -0,0 +1,61 @@
+! { dg-do compile }
+! { dg-additional-options "-O2 -march=armv8.3-a" }
+
+subroutine c_add_ab(n, a, c, b)         ! C += A * B
+  use iso_fortran_env, only: real64
+  implicit none
+  !GCC$ ATTRIBUTES noinline :: c_add_ab
+  integer, intent(in) :: n
+  complex(real64), intent(in)    :: a
+  complex(real64), intent(inout) :: c(*)
+  complex(real64), intent(in)    :: b(*)
+  integer :: k
+  do k = 1, n
+    c(k) = c(k) + a * b(k)
+  end do
+end subroutine c_add_ab
+
+subroutine c_sub_ab(n, a, c, b)         ! C -= A * B
+  use iso_fortran_env, only: real64
+  implicit none
+  !GCC$ ATTRIBUTES noinline :: c_sub_ab
+  integer, intent(in) :: n
+  complex(real64), intent(in)    :: a
+  complex(real64), intent(inout) :: c(*)
+  complex(real64), intent(in)    :: b(*)
+  integer :: k
+  do k = 1, n
+    c(k) = c(k) - a * b(k)
+  end do
+end subroutine c_sub_ab
+
+subroutine c_add_a_conjb(n, a, c, b)    ! C += A * conj(B)
+  use iso_fortran_env, only: real64
+  implicit none
+  !GCC$ ATTRIBUTES noinline :: c_add_a_conjb
+  integer, intent(in) :: n
+  complex(real64), intent(in)    :: a
+  complex(real64), intent(inout) :: c(*)
+  complex(real64), intent(in)    :: b(*)
+  integer :: k
+  do k = 1, n
+    c(k) = c(k) + a * conjg(b(k))
+  end do
+end subroutine c_add_a_conjb
+
+subroutine c_sub_a_conjb(n, a, c, b)    ! C -= A * conj(B)
+  use iso_fortran_env, only: real64
+  implicit none
+  !GCC$ ATTRIBUTES noinline :: c_sub_a_conjb
+  integer, intent(in) :: n
+  complex(real64), intent(in)    :: a
+  complex(real64), intent(inout) :: c(*)
+  complex(real64), intent(in)    :: b(*)
+  integer :: k
+  do k = 1, n
+    c(k) = c(k) - a * conjg(b(k))
+  end do
+end subroutine c_sub_a_conjb
+
+! { dg-final { scan-assembler-times {fcmla\s+v[0-9]+.2d, v[0-9]+.2d, v[0-9]+.2d, #0} 2 } }
+! { dg-final { scan-assembler-times {fcmla\s+v[0-9]+.2d, v[0-9]+.2d, v[0-9]+.2d, #270} 2 } }
diff --git a/gcc/testsuite/gfortran.target/aarch64/pr122408_2.f90 b/gcc/testsuite/gfortran.target/aarch64/pr122408_2.f90
new file mode 100644 (file)
index 0000000..feb6dc1
--- /dev/null
@@ -0,0 +1,140 @@
+! { dg-do run }
+! { dg-additional-options "-O2" }
+! { dg-additional-options "-O2 -march=armv8.3-a" { target arm_v8_3a_complex_neon_hw } }
+
+module util
+  use iso_fortran_env, only: real64, int64
+  implicit none
+contains
+  pure logical function bitwise_eq(x, y)
+    complex(real64), intent(in) :: x, y
+    integer(int64) :: xr, xi, yr, yi
+    xr = transfer(real(x,kind=real64), 0_int64)
+    xi = transfer(aimag(x),             0_int64)
+    yr = transfer(real(y,kind=real64),  0_int64)
+    yi = transfer(aimag(y),              0_int64)
+    bitwise_eq = (xr == yr) .and. (xi == yi)
+  end function bitwise_eq
+
+  subroutine check_equal(tag, got, ref, nfail)
+    character(*), intent(in) :: tag
+    complex(real64), intent(in) :: got(:), ref(:)
+    integer, intent(inout) :: nfail
+    integer :: i
+    do i = 1, size(got)
+      if (.not. bitwise_eq(got(i), ref(i))) then
+        nfail = nfail + 1
+        write(*,'(A,": mismatch at i=",I0, "  got=",2ES16.8,"  ref=",2ES16.8)') &
+             trim(tag), i, real(got(i)), aimag(got(i)), real(ref(i)), aimag(ref(i))
+      end if
+    end do
+  end subroutine check_equal
+end module util
+
+module fcmla_ops
+  use iso_fortran_env, only: real64
+  implicit none
+contains
+  subroutine c_add_ab(n, a, c, b)         ! C += A * B
+    !GCC$ ATTRIBUTES noinline :: c_add_ab
+    integer, intent(in) :: n
+    complex(real64), intent(in)    :: a
+    complex(real64), intent(inout) :: c(*)
+    complex(real64), intent(in)    :: b(*)
+    integer :: k
+    do k = 1, n
+      c(k) = c(k) + a * b(k)
+    end do
+  end subroutine c_add_ab
+
+  subroutine c_sub_ab(n, a, c, b)         ! C -= A * B
+    !GCC$ ATTRIBUTES noinline :: c_sub_ab
+    integer, intent(in) :: n
+    complex(real64), intent(in)    :: a
+    complex(real64), intent(inout) :: c(*)
+    complex(real64), intent(in)    :: b(*)
+    integer :: k
+    do k = 1, n
+      c(k) = c(k) - a * b(k)
+    end do
+  end subroutine c_sub_ab
+
+  subroutine c_add_a_conjb(n, a, c, b)    ! C += A * conj(B)
+    !GCC$ ATTRIBUTES noinline :: c_add_a_conjb
+    integer, intent(in) :: n
+    complex(real64), intent(in)    :: a
+    complex(real64), intent(inout) :: c(*)
+    complex(real64), intent(in)    :: b(*)
+    integer :: k
+    do k = 1, n
+      c(k) = c(k) + a * conjg(b(k))
+    end do
+  end subroutine c_add_a_conjb
+
+  subroutine c_sub_a_conjb(n, a, c, b)    ! C -= A * conj(B)
+    !GCC$ ATTRIBUTES noinline :: c_sub_a_conjb
+    integer, intent(in) :: n
+    complex(real64), intent(in)    :: a
+    complex(real64), intent(inout) :: c(*)
+    complex(real64), intent(in)    :: b(*)
+    integer :: k
+    do k = 1, n
+      c(k) = c(k) - a * conjg(b(k))
+    end do
+  end subroutine c_sub_a_conjb
+end module fcmla_ops
+
+program fcmla_accum_pairs
+  use iso_fortran_env, only: real64
+  use util
+  use fcmla_ops
+  implicit none
+
+  integer, parameter :: n = 4
+  complex(real64) :: a, b(n), c0(n)
+  complex(real64) :: c_add_ab_got(n),      c_add_ab_ref(n)
+  complex(real64) :: c_sub_ab_got(n),      c_sub_ab_ref(n)
+  complex(real64) :: c_add_conjb_got(n),   c_add_conjb_ref(n)
+  complex(real64) :: c_sub_conjb_got(n),   c_sub_conjb_ref(n)
+  integer :: i, fails
+
+  ! Constants (include a signed-zero lane)
+  a    = cmplx( 2.0_real64, -3.0_real64, kind=real64)
+  b(1) = cmplx( 1.5_real64, -2.0_real64, kind=real64)
+  b(2) = cmplx(-4.0_real64,  5.0_real64, kind=real64)
+  b(3) = cmplx(-0.0_real64,  0.0_real64, kind=real64)
+  b(4) = cmplx( 0.25_real64, 3.0_real64, kind=real64)
+
+  c0(1) = cmplx( 1.0_real64, -2.0_real64, kind=real64)
+  c0(2) = cmplx( 3.0_real64, -4.0_real64, kind=real64)
+  c0(3) = cmplx(-5.0_real64,  6.0_real64, kind=real64)
+  c0(4) = cmplx( 0.0_real64,  0.0_real64, kind=real64)
+
+  ! Run each form
+  c_add_ab_got    = c0; call c_add_ab     (n, a, c_add_ab_got,    b)
+  c_sub_ab_got    = c0; call c_sub_ab     (n, a, c_sub_ab_got,    b)
+  c_add_conjb_got = c0; call c_add_a_conjb(n, a, c_add_conjb_got, b)
+  c_sub_conjb_got = c0; call c_sub_a_conjb(n, a, c_sub_conjb_got, b)
+
+  ! Scalar references
+  do i = 1, n
+    c_add_ab_ref(i)    = c0(i) + a * b(i)
+    c_sub_ab_ref(i)    = c0(i) - a * b(i)
+    c_add_conjb_ref(i) = c0(i) + a * conjg(b(i))
+    c_sub_conjb_ref(i) = c0(i) - a * conjg(b(i))
+  end do
+
+  ! Bitwise checks
+  fails = 0
+  call check_equal("C +=  A*B       ", c_add_ab_got,    c_add_ab_ref,    fails)
+  call check_equal("C -=  A*B       ", c_sub_ab_got,    c_sub_ab_ref,    fails)
+  call check_equal("C +=  A*conj(B) ", c_add_conjb_got, c_add_conjb_ref, fails)
+  call check_equal("C -=  A*conj(B) ", c_sub_conjb_got, c_sub_conjb_ref, fails)
+
+  if (fails == 0) then
+    stop 0
+  else
+    stop 1
+  end if
+end program fcmla_accum_pairs
+
index c0dff90d9baf5ffaf0902326acc00ca78e1b563c..cebd9aa1c13c9af49374672fd85b827c58364c33 100644 (file)
@@ -847,15 +847,23 @@ compatible_complex_nodes_p (slp_compat_nodes_map_t *compat_cache,
   return true;
 }
 
+
+/* Check to see if the oprands to two multiplies, 2 each in LEFT_OP and
+   RIGHT_OP match a complex multiplication  or complex multiply-and-accumulate
+   or complex multiply-and-subtract pattern.  Do this using the permute cache
+   PERM_CACHE and the combination compatibility list COMPAT_CACHE.  If
+   the operation is successful the macthing operands are returned in OPS and
+   _STATUS indicates if the operation matched includes a conjugate of one of the
+   operands.  If the operation succeeds True is returned, otherwise False and
+   the values in ops are meaningless.  */
 static inline bool
 vect_validate_multiplication (slp_tree_to_load_perm_map_t *perm_cache,
                              slp_compat_nodes_map_t *compat_cache,
-                             vec<slp_tree> &left_op,
-                             vec<slp_tree> &right_op,
-                             bool subtract,
+                             const vec<slp_tree> &left_op,
+                             const vec<slp_tree> &right_op,
+                             bool subtract, vec<slp_tree> &ops,
                              enum _conj_status *_status)
 {
-  auto_vec<slp_tree> ops;
   enum _conj_status stats = CONJ_NONE;
 
   /* The complex operations can occur in two layouts and two permute sequences
@@ -886,31 +894,31 @@ vect_validate_multiplication (slp_tree_to_load_perm_map_t *perm_cache,
   bool neg0 = vect_match_expression_p (right_op[0], NEGATE_EXPR);
   bool neg1 = vect_match_expression_p (right_op[1], NEGATE_EXPR);
 
+  /* Create the combined inputs after remapping and flattening.  */
+  ops.create (4);
+  ops.safe_splice (left_op);
+  ops.safe_splice (right_op);
+
   /* Determine which style we're looking at.  We only have different ones
      whenever a conjugate is involved.  */
   if (neg0 && neg1)
     ;
   else if (neg0)
     {
-      right_op[0] = SLP_TREE_CHILDREN (right_op[0])[0];
+      ops[2] = SLP_TREE_CHILDREN (right_op[0])[0];
       stats = CONJ_FST;
       if (subtract)
        perm = 0;
     }
   else if (neg1)
     {
-      right_op[1] = SLP_TREE_CHILDREN (right_op[1])[0];
+      ops[3] = SLP_TREE_CHILDREN (right_op[1])[0];
       stats = CONJ_SND;
       perm = 1;
     }
 
   *_status = stats;
 
-  /* Flatten the inputs after we've remapped them.  */
-  ops.create (4);
-  ops.safe_splice (left_op);
-  ops.safe_splice (right_op);
-
   /* Extract out the elements to check.  */
   slp_tree op0 = ops[styles[style][0]];
   slp_tree op1 = ops[styles[style][1]];
@@ -1073,15 +1081,16 @@ complex_mul_pattern::matches (complex_operation_t op,
     return IFN_LAST;
 
   enum _conj_status status;
+  auto_vec<slp_tree> res_ops;
   if (!vect_validate_multiplication (perm_cache, compat_cache, left_op,
-                                    right_op, false, &status))
+                                    right_op, false, res_ops, &status))
     {
       /* Try swapping the order and re-trying since multiplication is
         commutative.  */
       std::swap (left_op[0], left_op[1]);
       std::swap (right_op[0], right_op[1]);
       if (!vect_validate_multiplication (perm_cache, compat_cache, left_op,
-                                        right_op, false, &status))
+                                        right_op, false, res_ops, &status))
        return IFN_LAST;
     }
 
@@ -1109,24 +1118,24 @@ complex_mul_pattern::matches (complex_operation_t op,
   if (add0)
     ops->quick_push (add0);
 
-  complex_perm_kinds_t kind = linear_loads_p (perm_cache, left_op[0]);
+  complex_perm_kinds_t kind = linear_loads_p (perm_cache, res_ops[0]);
   if (kind == PERM_EVENODD || kind == PERM_TOP)
     {
-      ops->quick_push (left_op[1]);
-      ops->quick_push (right_op[1]);
-      ops->quick_push (left_op[0]);
+      ops->quick_push (res_ops[1]);
+      ops->quick_push (res_ops[3]);
+      ops->quick_push (res_ops[0]);
     }
   else if (kind == PERM_EVENEVEN && status != CONJ_SND)
     {
-      ops->quick_push (left_op[0]);
-      ops->quick_push (right_op[0]);
-      ops->quick_push (left_op[1]);
+      ops->quick_push (res_ops[0]);
+      ops->quick_push (res_ops[2]);
+      ops->quick_push (res_ops[1]);
     }
   else
     {
-      ops->quick_push (left_op[0]);
-      ops->quick_push (right_op[1]);
-      ops->quick_push (left_op[1]);
+      ops->quick_push (res_ops[0]);
+      ops->quick_push (res_ops[3]);
+      ops->quick_push (res_ops[1]);
     }
 
   return ifn;
@@ -1298,15 +1307,17 @@ complex_fms_pattern::matches (complex_operation_t op,
     return IFN_LAST;
 
   enum _conj_status status;
+  auto_vec<slp_tree> res_ops;
   if (!vect_validate_multiplication (perm_cache, compat_cache, right_op,
-                                    left_op, true, &status))
+                                    left_op, true, res_ops, &status))
     {
       /* Try swapping the order and re-trying since multiplication is
         commutative.  */
       std::swap (left_op[0], left_op[1]);
       std::swap (right_op[0], right_op[1]);
+      auto_vec<slp_tree> res_ops;
       if (!vect_validate_multiplication (perm_cache, compat_cache, right_op,
-                                        left_op, true, &status))
+                                        left_op, true, res_ops, &status))
        return IFN_LAST;
     }
 
@@ -1321,20 +1332,20 @@ complex_fms_pattern::matches (complex_operation_t op,
   ops->truncate (0);
   ops->create (4);
 
-  complex_perm_kinds_t kind = linear_loads_p (perm_cache, right_op[0]);
+  complex_perm_kinds_t kind = linear_loads_p (perm_cache, res_ops[2]);
   if (kind == PERM_EVENODD)
     {
       ops->quick_push (l0node[0]);
-      ops->quick_push (right_op[0]);
-      ops->quick_push (right_op[1]);
-      ops->quick_push (left_op[1]);
+      ops->quick_push (res_ops[2]);
+      ops->quick_push (res_ops[3]);
+      ops->quick_push (res_ops[1]);
     }
   else
     {
       ops->quick_push (l0node[0]);
-      ops->quick_push (right_op[1]);
-      ops->quick_push (right_op[0]);
-      ops->quick_push (left_op[0]);
+      ops->quick_push (res_ops[3]);
+      ops->quick_push (res_ops[2]);
+      ops->quick_push (res_ops[0]);
     }
 
   return ifn;