]> git.ipfire.org Git - thirdparty/gcc.git/blobdiff - gcc/config/riscv/riscv-vector-costs.cc
Update copyright years.
[thirdparty/gcc.git] / gcc / config / riscv / riscv-vector-costs.cc
index 946eb4a9fc64f7912d2f6891ab5a1dc6c4bac638..b9fdfdc5e3a95962c3c311ab69d954becd5f93f0 100644 (file)
@@ -1,5 +1,5 @@
 /* Cost model implementation for RISC-V 'V' Extension for GNU compiler.
-   Copyright (C) 2023-2023 Free Software Foundation, Inc.
+   Copyright (C) 2023-2024 Free Software Foundation, Inc.
    Contributed by Juzhe Zhong (juzhe.zhong@rivai.ai), RiVAI Technologies Ltd.
 
 This file is part of GCC.
@@ -88,6 +88,75 @@ namespace riscv_vector {
         3. M1(M8) -> MF2(M4) -> MF4(M2) -> MF8(M1)
 */
 
+static bool
+is_gimple_assign_or_call (gimple *stmt)
+{
+  return is_gimple_assign (stmt) || is_gimple_call (stmt);
+}
+
+/* Return the program point of 1st vectorized lanes statement.  */
+static unsigned int
+get_first_lane_point (const vec<stmt_point> program_points,
+                     stmt_vec_info stmt_info)
+{
+  for (const auto program_point : program_points)
+    if (program_point.stmt_info == DR_GROUP_FIRST_ELEMENT (stmt_info))
+      return program_point.point;
+  return 0;
+}
+
+/* Return the program point of last vectorized lanes statement.  */
+static unsigned int
+get_last_lane_point (const vec<stmt_point> program_points,
+                    stmt_vec_info stmt_info)
+{
+  unsigned int max_point = 0;
+  for (auto s = DR_GROUP_FIRST_ELEMENT (stmt_info); s != NULL;
+       s = DR_GROUP_NEXT_ELEMENT (s))
+    {
+      for (const auto program_point : program_points)
+       if (program_point.stmt_info == s && program_point.point > max_point)
+         max_point = program_point.point;
+    }
+  return max_point;
+}
+
+/* Return the last variable that is in the live range list.  */
+static pair *
+get_live_range (hash_map<tree, pair> *live_ranges, tree arg)
+{
+  auto *r = live_ranges->get (arg);
+  if (r)
+    return r;
+  else
+    {
+      tree t = arg;
+      gimple *def_stmt = NULL;
+      while (t && TREE_CODE (t) == SSA_NAME && !r
+            && (def_stmt = SSA_NAME_DEF_STMT (t)))
+       {
+         if (gimple_assign_cast_p (def_stmt))
+           {
+             t = gimple_assign_rhs1 (def_stmt);
+             r = live_ranges->get (t);
+             def_stmt = NULL;
+           }
+         else
+           /* FIXME: Currently we don't see any fold for
+              non-conversion statements.  */
+           t = NULL_TREE;
+       }
+      if (r)
+       return r;
+      else
+       {
+         bool insert_p = live_ranges->put (arg, pair (0, 0));
+         gcc_assert (!insert_p);
+         return live_ranges->get (arg);
+       }
+    }
+}
+
 /* Collect all STMTs that are vectorized and compute their program points.
    Note that we don't care about the STMTs that are not vectorized and
    we only build the local graph (within a block) of program points.
@@ -130,17 +199,16 @@ compute_local_program_points (
            dump_printf_loc (MSG_NOTE, vect_location,
                             "Compute local program points for bb %d:\n",
                             bb->index);
-         for (si = gsi_start_bb (bbs[i]); !gsi_end_p (si); gsi_next (&si))
+         for (si = gsi_start_bb (bb); !gsi_end_p (si); gsi_next (&si))
            {
-             if (!(is_gimple_assign (gsi_stmt (si))
-                   || is_gimple_call (gsi_stmt (si))))
+             if (!is_gimple_assign_or_call (gsi_stmt (si)))
                continue;
              stmt_vec_info stmt_info = vinfo->lookup_stmt (gsi_stmt (si));
              enum stmt_vec_info_type type
                = STMT_VINFO_TYPE (vect_stmt_to_vectorize (stmt_info));
              if (type != undef_vec_info_type)
                {
-                 stmt_point info = {point, gsi_stmt (si)};
+                 stmt_point info = {point, gsi_stmt (si), stmt_info};
                  program_points.safe_push (info);
                  point++;
                  if (dump_enabled_p ())
@@ -209,9 +277,13 @@ compute_local_live_ranges (
            {
              unsigned int point = program_point.point;
              gimple *stmt = program_point.stmt;
+             stmt_vec_info stmt_info = program_point.stmt_info;
              tree lhs = gimple_get_lhs (stmt);
+             enum stmt_vec_info_type type
+               = STMT_VINFO_TYPE (vect_stmt_to_vectorize (stmt_info));
              if (lhs != NULL_TREE && is_gimple_reg (lhs)
-                 && !POINTER_TYPE_P (TREE_TYPE (lhs)))
+                 && (!POINTER_TYPE_P (TREE_TYPE (lhs))
+                     || type != store_vec_info_type))
                {
                  biggest_mode = get_biggest_mode (biggest_mode,
                                                   TYPE_MODE (TREE_TYPE (lhs)));
@@ -219,6 +291,10 @@ compute_local_live_ranges (
                  pair &live_range
                    = live_ranges->get_or_insert (lhs, &existed_p);
                  gcc_assert (!existed_p);
+                 if (STMT_VINFO_MEMORY_ACCESS_TYPE (program_point.stmt_info)
+                     == VMAT_LOAD_STORE_LANES)
+                   point = get_first_lane_point (program_points,
+                                                 program_point.stmt_info);
                  live_range = pair (point, point);
                }
              for (i = 0; i < gimple_num_args (stmt); i++)
@@ -233,7 +309,8 @@ compute_local_live_ranges (
                     the future.  */
                  if (poly_int_tree_p (var)
                      || (is_gimple_val (var)
-                         && !POINTER_TYPE_P (TREE_TYPE (var))))
+                         && (!POINTER_TYPE_P (TREE_TYPE (var))
+                             || type != load_vec_info_type)))
                    {
                      biggest_mode
                        = get_biggest_mode (biggest_mode,
@@ -241,13 +318,38 @@ compute_local_live_ranges (
                      bool existed_p = false;
                      pair &live_range
                        = live_ranges->get_or_insert (var, &existed_p);
+                     if (STMT_VINFO_MEMORY_ACCESS_TYPE (
+                           program_point.stmt_info)
+                         == VMAT_LOAD_STORE_LANES)
+                       point = get_last_lane_point (program_points,
+                                                    program_point.stmt_info);
+                     else if (existed_p)
+                       point = MAX (live_range.second, point);
                      if (existed_p)
                        /* We will grow the live range for each use.  */
                        live_range = pair (live_range.first, point);
                      else
-                       /* We assume the variable is live from the start of
-                          this block.  */
-                       live_range = pair (0, point);
+                       {
+                         gimple *def_stmt;
+                         if (TREE_CODE (var) == SSA_NAME
+                             && (def_stmt = SSA_NAME_DEF_STMT (var))
+                             && gimple_bb (def_stmt) == bb
+                             && is_gimple_assign_or_call (def_stmt))
+                           {
+                             live_ranges->remove (var);
+                             for (unsigned int j = 0;
+                                  j < gimple_num_args (def_stmt); j++)
+                               {
+                                 tree arg = gimple_arg (def_stmt, j);
+                                 auto *r = get_live_range (live_ranges, arg);
+                                 gcc_assert (r);
+                                 (*r).second = MAX (point, (*r).second);
+                               }
+                           }
+                         else
+                           /* The splat vector lives the whole block.  */
+                           live_range = pair (0, program_points.length ());
+                       }
                    }
                }
            }
@@ -271,13 +373,17 @@ compute_local_live_ranges (
    E.g. If mode = SImode, biggest_mode = DImode, LMUL = M4.
        Then return RVVM4SImode (LMUL = 4, element mode = SImode).  */
 static unsigned int
-compute_nregs_for_mode (machine_mode mode, machine_mode biggest_mode, int lmul)
+compute_nregs_for_mode (loop_vec_info loop_vinfo, machine_mode mode,
+                       machine_mode biggest_mode, int lmul)
 {
+  unsigned int rgroup_size = LOOP_VINFO_LENS (loop_vinfo).is_empty ()
+                              ? 1
+                              : LOOP_VINFO_LENS (loop_vinfo).length ();
   unsigned int mode_size = GET_MODE_SIZE (mode).to_constant ();
   unsigned int biggest_size = GET_MODE_SIZE (biggest_mode).to_constant ();
   gcc_assert (biggest_size >= mode_size);
   unsigned int ratio = biggest_size / mode_size;
-  return lmul / ratio;
+  return MAX (lmul / ratio, 1) * rgroup_size;
 }
 
 /* This function helps to determine whether current LMUL will cause
@@ -291,7 +397,7 @@ compute_nregs_for_mode (machine_mode mode, machine_mode biggest_mode, int lmul)
        mode.
      - Third, Return the maximum V_REGs are alive of the loop.  */
 static unsigned int
-max_number_of_live_regs (const basic_block bb,
+max_number_of_live_regs (loop_vec_info loop_vinfo, const basic_block bb,
                         const hash_map<tree, pair> &live_ranges,
                         unsigned int max_point, machine_mode biggest_mode,
                         int lmul)
@@ -310,10 +416,13 @@ max_number_of_live_regs (const basic_block bb,
        {
          machine_mode mode = TYPE_MODE (TREE_TYPE (var));
          unsigned int nregs
-           = compute_nregs_for_mode (mode, biggest_mode, lmul);
+           = compute_nregs_for_mode (loop_vinfo, mode, biggest_mode, lmul);
          live_vars_vec[i] += nregs;
          if (live_vars_vec[i] > max_nregs)
-           max_nregs = live_vars_vec[i];
+           {
+             max_nregs = live_vars_vec[i];
+             live_point = i;
+           }
        }
     }
 
@@ -390,29 +499,38 @@ non_contiguous_memory_access_p (stmt_vec_info stmt_info)
 
 /* Return the LMUL of the current analysis.  */
 static int
-compute_estimated_lmul (loop_vec_info other_loop_vinfo, machine_mode mode)
+compute_estimated_lmul (loop_vec_info loop_vinfo, machine_mode mode)
 {
   gcc_assert (GET_MODE_BITSIZE (mode).is_constant ());
-  int regno_alignment
-    = riscv_get_v_regno_alignment (other_loop_vinfo->vector_mode);
-  if (known_eq (LOOP_VINFO_SLP_UNROLLING_FACTOR (other_loop_vinfo), 1U))
+  int regno_alignment = riscv_get_v_regno_alignment (loop_vinfo->vector_mode);
+  if (riscv_v_ext_vls_mode_p (loop_vinfo->vector_mode))
+    return regno_alignment;
+  else if (known_eq (LOOP_VINFO_SLP_UNROLLING_FACTOR (loop_vinfo), 1U))
     {
-      int estimated_vf = vect_vf_for_cost (other_loop_vinfo);
+      int estimated_vf = vect_vf_for_cost (loop_vinfo);
       return estimated_vf * GET_MODE_BITSIZE (mode).to_constant ()
             / TARGET_MIN_VLEN;
     }
-  else if (regno_alignment > 1)
-    return regno_alignment;
   else
     {
-      int ratio;
-      if (can_div_trunc_p (BYTES_PER_RISCV_VECTOR,
-                          LOOP_VINFO_SLP_UNROLLING_FACTOR (other_loop_vinfo),
-                          &ratio))
-       return TARGET_MAX_LMUL / ratio;
-      else
-       gcc_unreachable ();
+      /* Estimate the VLA SLP LMUL.  */
+      if (regno_alignment > RVV_M1)
+       return regno_alignment;
+      else if (mode != QImode
+              || LOOP_VINFO_SLP_UNROLLING_FACTOR (loop_vinfo).is_constant ())
+       {
+         int ratio;
+         if (can_div_trunc_p (BYTES_PER_RISCV_VECTOR,
+                              GET_MODE_SIZE (loop_vinfo->vector_mode), &ratio))
+           {
+             if (ratio == 1)
+               return RVV_M4;
+             else if (ratio == 2)
+               return RVV_M2;
+           }
+       }
     }
+  return 0;
 }
 
 /* Update the live ranges according PHI.
@@ -498,7 +616,7 @@ update_local_live_ranges (
                      auto &program_points = (*program_points_per_bb.get (bb));
                      if (program_points.is_empty ())
                        {
-                         stmt_point info = {1, phi};
+                         stmt_point info = {1, phi, stmt_info};
                          program_points.safe_push (info);
                        }
                      if (dump_enabled_p ())
@@ -536,13 +654,15 @@ update_local_live_ranges (
        }
       for (si = gsi_start_bb (bb); !gsi_end_p (si); gsi_next (&si))
        {
-         if (!(is_gimple_assign (gsi_stmt (si))
-               || is_gimple_call (gsi_stmt (si))))
+         if (!is_gimple_assign_or_call (gsi_stmt (si)))
            continue;
          stmt_vec_info stmt_info = vinfo->lookup_stmt (gsi_stmt (si));
          enum stmt_vec_info_type type
            = STMT_VINFO_TYPE (vect_stmt_to_vectorize (stmt_info));
-         if (non_contiguous_memory_access_p (stmt_info))
+         if (non_contiguous_memory_access_p (stmt_info)
+             /* LOAD_LANES/STORE_LANES doesn't need a perm indice.  */
+             && STMT_VINFO_MEMORY_ACCESS_TYPE (stmt_info)
+                  != VMAT_LOAD_STORE_LANES)
            {
              /* For non-adjacent load/store STMT, we will potentially
                 convert it into:
@@ -571,19 +691,37 @@ update_local_live_ranges (
                dump_printf_loc (MSG_NOTE, vect_location,
                                 "Add perm indice %T, start = 0, end = %d\n",
                                 sel, max_point);
+             if (!LOOP_VINFO_LENS (loop_vinfo).is_empty ()
+                 && LOOP_VINFO_LENS (loop_vinfo).length () > 1)
+               {
+                 /* If we are vectorizing a permutation when the rgroup number
+                    > 1, we will need additional mask to shuffle the second
+                    vector.  */
+                 tree mask = build_decl (UNKNOWN_LOCATION, VAR_DECL,
+                                         get_identifier ("vect_perm_mask"),
+                                         boolean_type_node);
+                 pair &live_range
+                   = live_ranges->get_or_insert (mask, &existed_p);
+                 gcc_assert (!existed_p);
+                 live_range = pair (0, max_point);
+                 if (dump_enabled_p ())
+                   dump_printf_loc (MSG_NOTE, vect_location,
+                                    "Add perm mask %T, start = 0, end = %d\n",
+                                    mask, max_point);
+               }
            }
        }
     }
 }
 
-/* Return true that the LMUL of new COST model is preferred.  */
+/* Compute the maximum live V_REGS.  */
 static bool
-preferred_new_lmul_p (loop_vec_info other_loop_vinfo)
+has_unexpected_spills_p (loop_vec_info loop_vinfo)
 {
   /* Compute local program points.
      It's a fast and effective computation.  */
   hash_map<basic_block, vec<stmt_point>> program_points_per_bb;
-  compute_local_program_points (other_loop_vinfo, program_points_per_bb);
+  compute_local_program_points (loop_vinfo, program_points_per_bb);
 
   /* Compute local live ranges.  */
   hash_map<basic_block, hash_map<tree, pair>> live_ranges_per_bb;
@@ -591,34 +729,38 @@ preferred_new_lmul_p (loop_vec_info other_loop_vinfo)
     = compute_local_live_ranges (program_points_per_bb, live_ranges_per_bb);
 
   /* Update live ranges according to PHI.  */
-  update_local_live_ranges (other_loop_vinfo, program_points_per_bb,
+  update_local_live_ranges (loop_vinfo, program_points_per_bb,
                            live_ranges_per_bb, &biggest_mode);
 
-  int lmul = compute_estimated_lmul (other_loop_vinfo, biggest_mode);
+  int lmul = compute_estimated_lmul (loop_vinfo, biggest_mode);
   /* TODO: We calculate the maximum live vars base on current STMTS
      sequence.  We can support live range shrink if it can give us
      big improvement in the future.  */
-  if (!live_ranges_per_bb.is_empty ())
+  if (lmul > RVV_M1)
     {
-      unsigned int max_nregs = 0;
-      for (hash_map<basic_block, hash_map<tree, pair>>::iterator iter
-          = live_ranges_per_bb.begin ();
-          iter != live_ranges_per_bb.end (); ++iter)
+      if (!live_ranges_per_bb.is_empty ())
        {
-         basic_block bb = (*iter).first;
-         unsigned int max_point
-           = (*program_points_per_bb.get (bb)).length () + 1;
-         if ((*iter).second.is_empty ())
-           continue;
-         /* We prefer larger LMUL unless it causes register spillings.  */
-         unsigned int nregs
-           = max_number_of_live_regs (bb, (*iter).second, max_point,
-                                      biggest_mode, lmul);
-         if (nregs > max_nregs)
-           max_nregs = nregs;
+         unsigned int max_nregs = 0;
+         for (hash_map<basic_block, hash_map<tree, pair>>::iterator iter
+              = live_ranges_per_bb.begin ();
+              iter != live_ranges_per_bb.end (); ++iter)
+           {
+             basic_block bb = (*iter).first;
+             unsigned int max_point
+               = (*program_points_per_bb.get (bb)).length () + 1;
+             if ((*iter).second.is_empty ())
+               continue;
+             /* We prefer larger LMUL unless it causes register spillings. */
+             unsigned int nregs
+               = max_number_of_live_regs (loop_vinfo, bb, (*iter).second,
+                                          max_point, biggest_mode, lmul);
+             if (nregs > max_nregs)
+               max_nregs = nregs;
+           }
+         live_ranges_per_bb.empty ();
+         if (max_nregs > V_REG_NUM)
+           return true;
        }
-      live_ranges_per_bb.empty ();
-      return max_nregs > V_REG_NUM;
     }
   if (!program_points_per_bb.is_empty ())
     {
@@ -632,7 +774,7 @@ preferred_new_lmul_p (loop_vec_info other_loop_vinfo)
        }
       program_points_per_bb.empty ();
     }
-  return lmul > RVV_M1;
+  return false;
 }
 
 costs::costs (vec_info *vinfo, bool costing_for_scalar)
@@ -667,6 +809,30 @@ costs::analyze_loop_vinfo (loop_vec_info loop_vinfo)
   /* Detect whether we're vectorizing for VLA and should apply the unrolling
      heuristic described above m_unrolled_vls_niters.  */
   record_potential_vls_unrolling (loop_vinfo);
+
+  /* Detect whether the LOOP has unexpected spills.  */
+  record_potential_unexpected_spills (loop_vinfo);
+}
+
+/* Analyze the vectorized program stataments and use dynamic LMUL
+   heuristic to detect whether the loop has unexpected spills.  */
+void
+costs::record_potential_unexpected_spills (loop_vec_info loop_vinfo)
+{
+  /* We only want to apply the heuristic if LOOP_VINFO is being
+     vectorized for VLA and known NITERS VLS loop.  */
+  if (riscv_autovec_lmul == RVV_DYNAMIC
+      && (m_cost_type == VLA_VECTOR_COST
+         || (m_cost_type == VLS_VECTOR_COST
+             && LOOP_VINFO_NITERS_KNOWN_P (loop_vinfo))))
+    {
+      bool post_dom_available_p = dom_info_available_p (CDI_POST_DOMINATORS);
+      if (!post_dom_available_p)
+       calculate_dominance_info (CDI_POST_DOMINATORS);
+      m_has_unexpected_spills_p = has_unexpected_spills_p (loop_vinfo);
+      if (!post_dom_available_p)
+       free_dominance_info (CDI_POST_DOMINATORS);
+    }
 }
 
 /* Decide whether to use the unrolling heuristic described above
@@ -762,19 +928,39 @@ costs::better_main_loop_than_p (const vector_costs *uncast_other) const
          return other_prefer_unrolled;
        }
     }
-
-  if (!LOOP_VINFO_NITERS_KNOWN_P (this_loop_vinfo)
-      && riscv_autovec_lmul == RVV_DYNAMIC)
+  else if (riscv_autovec_lmul == RVV_DYNAMIC)
     {
-      if (!riscv_v_ext_vector_mode_p (this_loop_vinfo->vector_mode))
-       return false;
-      bool post_dom_available_p = dom_info_available_p (CDI_POST_DOMINATORS);
-      if (!post_dom_available_p)
-       calculate_dominance_info (CDI_POST_DOMINATORS);
-      bool preferred_p = preferred_new_lmul_p (other_loop_vinfo);
-      if (!post_dom_available_p)
-       free_dominance_info (CDI_POST_DOMINATORS);
-      return preferred_p;
+      if (other->m_has_unexpected_spills_p)
+       {
+         if (dump_enabled_p ())
+           dump_printf_loc (MSG_NOTE, vect_location,
+                            "Preferring smaller LMUL loop because"
+                            " it has unexpected spills\n");
+         return true;
+       }
+      else if (riscv_v_ext_vector_mode_p (other_loop_vinfo->vector_mode))
+       {
+         if (LOOP_VINFO_NITERS_KNOWN_P (other_loop_vinfo))
+           {
+             if (maybe_gt (LOOP_VINFO_INT_NITERS (this_loop_vinfo),
+                           LOOP_VINFO_VECT_FACTOR (this_loop_vinfo)))
+               {
+                 if (dump_enabled_p ())
+                   dump_printf_loc (MSG_NOTE, vect_location,
+                                    "Keep current LMUL loop because"
+                                    " known NITERS exceed the new VF\n");
+                 return false;
+               }
+           }
+         else
+           {
+             if (dump_enabled_p ())
+               dump_printf_loc (MSG_NOTE, vect_location,
+                                "Keep current LMUL loop because"
+                                " it is unknown NITERS\n");
+             return false;
+           }
+       }
     }
 
   return vector_costs::better_main_loop_than_p (other);