]> git.ipfire.org Git - thirdparty/gcc.git/commitdiff
Improve memcmpeq for 512-bit vector with vpcmpeq + kortest.
authorliuhongt <hongtao.liu@intel.com>
Mon, 9 Oct 2023 07:07:54 +0000 (15:07 +0800)
committerliuhongt <hongtao.liu@intel.com>
Mon, 30 Oct 2023 03:10:01 +0000 (11:10 +0800)
When 2 vectors are equal, kmask is allones and kortest will set CF,
else CF will be cleared.

So CF bit can be used to check for the result of the comparison.

Before:
        vmovdqu (%rsi), %ymm0
        vpxorq  (%rdi), %ymm0, %ymm0
        vptest  %ymm0, %ymm0
        jne     .L2
        vmovdqu 32(%rsi), %ymm0
        vpxorq  32(%rdi), %ymm0, %ymm0
        vptest  %ymm0, %ymm0
        je      .L5
.L2:
        movl    $1, %eax
        xorl    $1, %eax
        vzeroupper
        ret

After:
        vmovdqu64       (%rsi), %zmm0
        xorl    %eax, %eax
        vpcmpeqd        (%rdi), %zmm0, %k0
        kortestw        %k0, %k0
        setc    %al
        vzeroupper
        ret

gcc/ChangeLog:

PR target/104610
* config/i386/i386-expand.cc (ix86_expand_branch): Handle
512-bit vector with vpcmpeq + kortest.
* config/i386/i386.md (cbranchxi4): New expander.
* config/i386/sse.md: (cbranch<mode>4): Extend to V16SImode
and V8DImode.

gcc/testsuite/ChangeLog:

* gcc.target/i386/pr104610-2.c: New test.

gcc/config/i386/i386-expand.cc
gcc/config/i386/i386.md
gcc/config/i386/sse.md
gcc/testsuite/gcc.target/i386/pr104610-2.c [new file with mode: 0644]

index 768053c480525bc25e73a5b35d56589b5fadee88..6ae5830037d280145ed562067d625c8b2d456101 100644 (file)
@@ -2413,30 +2413,53 @@ ix86_expand_branch (enum rtx_code code, rtx op0, rtx op1, rtx label)
   rtx tmp;
 
   /* Handle special case - vector comparsion with boolean result, transform
-     it using ptest instruction.  */
+     it using ptest instruction or vpcmpeq + kortest.  */
   if (GET_MODE_CLASS (mode) == MODE_VECTOR_INT
       || (mode == TImode && !TARGET_64BIT)
-      || mode == OImode)
+      || mode == OImode
+      || GET_MODE_SIZE (mode) == 64)
     {
-      rtx flag = gen_rtx_REG (CCZmode, FLAGS_REG);
-      machine_mode p_mode = GET_MODE_SIZE (mode) == 32 ? V4DImode : V2DImode;
+      unsigned msize = GET_MODE_SIZE (mode);
+      machine_mode p_mode
+       = msize == 64 ? V16SImode : msize == 32 ? V4DImode : V2DImode;
+      /* kortest set CF when result is 0xFFFF (op0 == op1).  */
+      rtx flag = gen_rtx_REG (msize == 64 ? CCCmode : CCZmode, FLAGS_REG);
 
       gcc_assert (code == EQ || code == NE);
 
-      if (GET_MODE_CLASS (mode) != MODE_VECTOR_INT)
+      /* Using vpcmpeq zmm zmm k + kortest for 512-bit vectors.  */
+      if (msize == 64)
        {
-         op0 = lowpart_subreg (p_mode, force_reg (mode, op0), mode);
-         op1 = lowpart_subreg (p_mode, force_reg (mode, op1), mode);
-         mode = p_mode;
+         if (mode != V16SImode)
+           {
+             op0 = lowpart_subreg (p_mode, force_reg (mode, op0), mode);
+             op1 = lowpart_subreg (p_mode, force_reg (mode, op1), mode);
+           }
+
+         tmp = gen_reg_rtx (HImode);
+         emit_insn (gen_avx512f_cmpv16si3 (tmp, op0, op1, GEN_INT (0)));
+         emit_insn (gen_kortesthi_ccc (tmp, tmp));
+       }
+      /* Using ptest for 128/256-bit vectors.  */
+      else
+       {
+         if (GET_MODE_CLASS (mode) != MODE_VECTOR_INT)
+           {
+             op0 = lowpart_subreg (p_mode, force_reg (mode, op0), mode);
+             op1 = lowpart_subreg (p_mode, force_reg (mode, op1), mode);
+             mode = p_mode;
+           }
+
+         /* Generate XOR since we can't check that one operand is zero
+            vector.  */
+         tmp = gen_reg_rtx (mode);
+         emit_insn (gen_rtx_SET (tmp, gen_rtx_XOR (mode, op0, op1)));
+         tmp = gen_lowpart (p_mode, tmp);
+         emit_insn (gen_rtx_SET (gen_rtx_REG (CCZmode, FLAGS_REG),
+                                 gen_rtx_UNSPEC (CCZmode,
+                                                 gen_rtvec (2, tmp, tmp),
+                                                 UNSPEC_PTEST)));
        }
-      /* Generate XOR since we can't check that one operand is zero vector.  */
-      tmp = gen_reg_rtx (mode);
-      emit_insn (gen_rtx_SET (tmp, gen_rtx_XOR (mode, op0, op1)));
-      tmp = gen_lowpart (p_mode, tmp);
-      emit_insn (gen_rtx_SET (gen_rtx_REG (CCZmode, FLAGS_REG),
-                             gen_rtx_UNSPEC (CCZmode,
-                                             gen_rtvec (2, tmp, tmp),
-                                             UNSPEC_PTEST)));
       tmp = gen_rtx_fmt_ee (code, VOIDmode, flag, const0_rtx);
       tmp = gen_rtx_IF_THEN_ELSE (VOIDmode, tmp,
                                  gen_rtx_LABEL_REF (VOIDmode, label),
index eb4121b3f1e621a33250f6a7d152e260573b8d48..92fbd57bae0fbdb14702e626b373f86592202109 100644 (file)
   DONE;
 })
 
+(define_expand "cbranchxi4"
+  [(set (reg:CC FLAGS_REG)
+       (compare:CC (match_operand:XI 1 "nonimmediate_operand")
+                   (match_operand:XI 2 "nonimmediate_operand")))
+   (set (pc) (if_then_else
+              (match_operator 0 "bt_comparison_operator"
+               [(reg:CC FLAGS_REG) (const_int 0)])
+              (label_ref (match_operand 3))
+              (pc)))]
+  "TARGET_AVX512F && TARGET_EVEX512 && !TARGET_PREFER_AVX256"
+{
+  ix86_expand_branch (GET_CODE (operands[0]),
+                     operands[1], operands[2], operands[3]);
+  DONE;
+})
+
 (define_expand "cstore<mode>4"
   [(set (reg:CC FLAGS_REG)
        (compare:CC (match_operand:SDWIM 2 "nonimmediate_operand")
index e2a7cbeb722a5c02bb5b87163405b64d6ee05bb1..906212fb4c143b6f008d98439df4ca6adbb958b2 100644 (file)
    (set_attr "type" "msklog")
    (set_attr "prefix" "vex")])
 
-(define_insn "kortest<mode>"
-  [(set (reg:CC FLAGS_REG)
-       (unspec:CC
+(define_insn "*kortest<mode>"
+  [(set (reg FLAGS_REG)
+       (unspec
          [(match_operand:SWI1248_AVX512BWDQ 0 "register_operand" "k")
           (match_operand:SWI1248_AVX512BWDQ 1 "register_operand" "k")]
          UNSPEC_KORTEST))]
    (set_attr "type" "msklog")
    (set_attr "prefix" "vex")])
 
+(define_insn "kortest<mode>_ccc"
+  [(set (reg:CCC FLAGS_REG)
+       (unspec:CCC
+         [(match_operand:SWI1248_AVX512BWDQ 0 "register_operand")
+          (match_operand:SWI1248_AVX512BWDQ 1 "register_operand")]
+         UNSPEC_KORTEST))]
+  "TARGET_AVX512F")
+
+(define_insn "kortest<mode>_ccz"
+  [(set (reg:CCZ FLAGS_REG)
+       (unspec:CCZ
+         [(match_operand:SWI1248_AVX512BWDQ 0 "register_operand")
+          (match_operand:SWI1248_AVX512BWDQ 1 "register_operand")]
+         UNSPEC_KORTEST))]
+  "TARGET_AVX512F")
+
+(define_expand "kortest<mode>"
+  [(set (reg:CC FLAGS_REG)
+       (unspec:CC
+         [(match_operand:SWI1248_AVX512BWDQ 0 "register_operand")
+          (match_operand:SWI1248_AVX512BWDQ 1 "register_operand")]
+         UNSPEC_KORTEST))]
+  "TARGET_AVX512F")
+
 (define_insn "kunpckhi"
   [(set (match_operand:HI 0 "register_operand" "=k")
        (ior:HI
 
 (define_expand "cbranch<mode>4"
   [(set (reg:CC FLAGS_REG)
-       (compare:CC (match_operand:VI48_AVX 1 "register_operand")
-                   (match_operand:VI48_AVX 2 "nonimmediate_operand")))
+       (compare:CC (match_operand:VI48_AVX_AVX512F 1 "register_operand")
+                   (match_operand:VI48_AVX_AVX512F 2 "nonimmediate_operand")))
    (set (pc) (if_then_else
               (match_operator 0 "bt_comparison_operator"
                [(reg:CC FLAGS_REG) (const_int 0)])
               (label_ref (match_operand 3))
               (pc)))]
-  "TARGET_SSE4_1"
+  "TARGET_SSE4_1 && (<MODE_SIZE> != 64 || !TARGET_PREFER_AVX256)"
 {
   ix86_expand_branch (GET_CODE (operands[0]),
                      operands[1], operands[2], operands[3]);
diff --git a/gcc/testsuite/gcc.target/i386/pr104610-2.c b/gcc/testsuite/gcc.target/i386/pr104610-2.c
new file mode 100644 (file)
index 0000000..999ef92
--- /dev/null
@@ -0,0 +1,14 @@
+/* { dg-do compile } */
+/* { dg-options "-mavx512f -O2 -mtune=generic" } */
+/* { dg-final { scan-assembler-times {(?n)vpcmpeq.*zmm} 2 } } */
+/* { dg-final { scan-assembler-times {(?n)kortest.*k[0-7]} 2 } } */
+
+int compare (const char* s1, const char* s2)
+{
+  return __builtin_memcmp (s1, s2, 64) == 0;
+}
+
+int compare1 (const char* s1, const char* s2)
+{
+  return __builtin_memcmp (s1, s2, 64) != 0;
+}