]> git.ipfire.org Git - thirdparty/gcc.git/commitdiff
middle-end: simplify complex if expressions where comparisons are inverse of one...
authorTamar Christina <tamar.christina@arm.com>
Mon, 12 Dec 2022 15:21:39 +0000 (15:21 +0000)
committerTamar Christina <tamar.christina@arm.com>
Mon, 12 Dec 2022 15:21:39 +0000 (15:21 +0000)
This optimizes the following sequence

  ((a < b) & c) | ((a >= b) & d)

into

  (a < b ? c : d) & 1

for scalar and on vector we can omit the & 1.

Also recognizes

  (-(a < b) & c) | (-(a >= b) & d)

into

  a < b ? c : d

This changes the code generation from

zoo2:
cmp     w0, w1
cset    w0, lt
cset    w1, ge
and     w0, w0, w2
and     w1, w1, w3
orr     w0, w0, w1
ret

into

cmp w0, w1
csel w0, w2, w3, lt
and w0, w0, 1
ret

and significantly reduces the number of selects we have to do in the vector
code.

gcc/ChangeLog:

* match.pd: Add new rule.

gcc/testsuite/ChangeLog:

* gcc.target/aarch64/if-compare_1.c: New test.
* gcc.target/aarch64/if-compare_2.c: New test.

gcc/match.pd
gcc/testsuite/gcc.target/aarch64/if-compare_1.c [new file with mode: 0644]
gcc/testsuite/gcc.target/aarch64/if-compare_2.c [new file with mode: 0644]

index fdba5833beb8f0029a2b11cae0ac4d85e4118f85..c48fe2d6bf20d84fc349259e8a370c7acc4cda9a 100644 (file)
@@ -1906,6 +1906,61 @@ DEFINE_INT_AND_FLOAT_ROUND_FN (RINT)
  (if (INTEGRAL_TYPE_P (type))
   (bit_and @0 @1)))
 
+(for cmp (tcc_comparison)
+     icmp (inverted_tcc_comparison)
+ /* Fold (((a < b) & c) | ((a >= b) & d)) into (a < b ? c : d) & 1.  */
+ (simplify
+  (bit_ior
+   (bit_and:c (convert? (cmp@0  @01 @02)) @3)
+   (bit_and:c (convert? (icmp@4 @01 @02)) @5))
+    (if (INTEGRAL_TYPE_P (type)
+        /* The scalar version has to be canonicalized after vectorization
+           because it makes unconditional loads conditional ones, which
+           means we lose vectorization because the loads may trap.  */
+        && canonicalize_math_after_vectorization_p ())
+     (bit_and (cond @0 @3 @5) { build_one_cst (type); })))
+
+ /* Fold ((-(a < b) & c) | (-(a >= b) & d)) into a < b ? c : d.  This is
+    canonicalized further and we recognize the conditional form:
+    (a < b ? c : 0) | (a >= b ? d : 0) into a < b ? c : d.  */
+ (simplify
+  (bit_ior
+   (cond (cmp@0  @01 @02) @3 zerop)
+   (cond (icmp@4 @01 @02) @5 zerop))
+    (if (INTEGRAL_TYPE_P (type)
+        /* The scalar version has to be canonicalized after vectorization
+           because it makes unconditional loads conditional ones, which
+           means we lose vectorization because the loads may trap.  */
+        && canonicalize_math_after_vectorization_p ())
+    (cond @0 @3 @5)))
+
+ /* Vector Fold (((a < b) & c) | ((a >= b) & d)) into a < b ? c : d. 
+    and ((~(a < b) & c) | (~(a >= b) & d)) into a < b ? c : d.  */
+ (simplify
+  (bit_ior
+   (bit_and:c (vec_cond:s (cmp@0 @6 @7) @4 @5) @2)
+   (bit_and:c (vec_cond:s (icmp@1 @6 @7) @4 @5) @3))
+    (if (integer_zerop (@5))
+     (switch
+      (if (integer_onep (@4))
+       (bit_and (vec_cond @0 @2 @3) @4))
+       (if (integer_minus_onep (@4))
+        (vec_cond @0 @2 @3)))
+    (if (integer_zerop (@4))
+     (switch
+      (if (integer_onep (@5))
+       (bit_and (vec_cond @0 @3 @2) @5))
+      (if (integer_minus_onep (@5))
+       (vec_cond @0 @3 @2))))))
+
+ /* Scalar Vectorized Fold ((-(a < b) & c) | (-(a >= b) & d))
+    into a < b ? d : c.  */
+ (simplify
+  (bit_ior
+   (vec_cond:s (cmp@0 @4 @5) @2 integer_zerop)
+   (vec_cond:s (icmp@1 @4 @5) @3 integer_zerop))
+    (vec_cond @0 @2 @3)))
+
 /* Transform X & -Y into X * Y when Y is { 0 or 1 }.  */
 (simplify
  (bit_and:c (convert? (negate zero_one_valued_p@0)) @1)
diff --git a/gcc/testsuite/gcc.target/aarch64/if-compare_1.c b/gcc/testsuite/gcc.target/aarch64/if-compare_1.c
new file mode 100644 (file)
index 0000000..53bbd77
--- /dev/null
@@ -0,0 +1,47 @@
+/* { dg-do run } */
+/* { dg-additional-options "-O -save-temps" } */
+/* { dg-final { check-function-bodies "**" "" "" { target { le } } } } */
+
+extern void abort ();
+
+/*
+**zoo1:
+**     cmp     w0, w1
+**     csel    w0, w2, w3, lt
+**     and     w0, w0, 1
+**     ret
+*/
+__attribute((noipa, noinline))
+int zoo1 (int a, int b, int c, int d)
+{
+   return ((a < b) & c) | ((a >= b) & d);
+}
+
+/*
+**zoo2:
+**     cmp     w0, w1
+**     csel    w0, w2, w3, lt
+**     ret
+*/
+__attribute((noipa, noinline))
+int zoo2 (int a, int b, int c, int d)
+{
+   return (-(a < b) & c) | (-(a >= b) & d);
+}
+
+int main ()
+{
+  if (zoo1 (-3, 3, 5, 8) != 1)
+    abort ();
+
+  if (zoo1 (3, -3, 5, 8) != 0)
+    abort ();
+
+  if (zoo2 (-3, 3, 5, 8) != 5)
+    abort ();
+
+  if (zoo2 (3, -3, 5, 8) != 8)
+    abort ();
+
+  return 0;
+}
diff --git a/gcc/testsuite/gcc.target/aarch64/if-compare_2.c b/gcc/testsuite/gcc.target/aarch64/if-compare_2.c
new file mode 100644 (file)
index 0000000..14988ab
--- /dev/null
@@ -0,0 +1,96 @@
+/* { dg-do run } */
+/* { dg-additional-options "-O3 -std=c99 -save-temps" } */
+/* { dg-final { check-function-bodies "**" "" "" { target { le } } } } */
+
+#pragma GCC target "+nosve"
+
+#include <string.h>
+
+typedef int v4si __attribute__ ((vector_size (16)));
+
+/*
+**foo1:
+**     cmgt    v0.4s, v1.4s, v0.4s
+**     bsl     v0.16b, v2.16b, v3.16b
+**     ret
+*/
+v4si foo1 (v4si a, v4si b, v4si c, v4si d) {
+    return ((a < b) & c) | ((a >= b) & d);
+}
+
+/*
+**foo2:
+**     cmgt    v0.4s, v1.4s, v0.4s
+**     bsl     v0.16b, v3.16b, v2.16b
+**     ret
+*/
+v4si foo2 (v4si a, v4si b, v4si c, v4si d) {
+    return (~(a < b) & c) | (~(a >= b) & d);
+}
+
+
+/**
+**bar1:
+**...
+**     cmge    v[0-9]+.4s, v[0-9]+.4s, v[0-9]+.4s
+**     bsl     v[0-9]+.16b, v[0-9]+.16b, v[0-9]+.16b
+**     and     v[0-9]+.16b, v[0-9]+.16b, v[0-9]+.16b
+**...
+*/
+void bar1 (int * restrict a, int * restrict b, int * restrict c,
+         int * restrict d, int * restrict res, int n)
+{
+  for (int i = 0; i < (n & -4); i++)
+    res[i] = ((a[i] < b[i]) & c[i]) | ((a[i] >= b[i]) & d[i]);
+}
+
+/**
+**bar2:
+**...
+**     cmge    v[0-9]+.4s, v[0-9]+.4s, v[0-9]+.4s
+**     bsl     v[0-9]+.16b, v[0-9]+.16b, v[0-9]+.16b
+**...
+*/
+void bar2 (int * restrict a, int * restrict b, int * restrict c,
+         int * restrict d, int * restrict res, int n)
+{
+  for (int i = 0; i < (n & -4); i++)
+    res[i] = (-(a[i] < b[i]) & c[i]) | (-(a[i] >= b[i]) & d[i]);
+}
+
+extern void abort ();
+
+int main ()
+{
+
+  v4si a = { -3, -3, -3, -3 };
+  v4si b = { 3, 3, 3, 3 };
+  v4si c = { 5, 5, 5, 5 };
+  v4si d = { 8, 8, 8, 8 };
+
+  v4si res1 = foo1 (a, b, c, d);
+  if (memcmp (&res1, &c, 16UL) != 0)
+    abort ();
+
+  v4si res2 = foo2 (a, b, c, d);
+  if (memcmp (&res2, &d, 16UL) != 0)
+   abort ();
+
+  int ar[4] = { -3, -3, -3, -3 };
+  int br[4] = { 3, 3, 3, 3 };
+  int cr[4] = { 5, 5, 5, 5 };
+  int dr[4] = { 8, 8, 8, 8 };
+
+  int exp1[4] = { 1, 1, 1, 1 };
+  int res3[4];
+  bar1 ((int*)&ar, (int*)&br, (int*)&cr, (int*)&dr, (int*)&res3, 4);
+  if (memcmp (&res3, &exp1, 16UL) != 0)
+    abort ();
+
+  int res4[4];
+  bar2 ((int*)&ar, (int*)&br, (int*)&cr, (int*)&dr, (int*)&res4, 4);
+  if (memcmp (&res4, &cr, 16UL) != 0)
+    abort ();
+
+  return 0;
+}