]> git.ipfire.org Git - thirdparty/gcc.git/commitdiff
RISC-V: Prepare dynamic LMUL heuristic for SLP.
author: Robin Dapp <rdapp@ventanamicro.com>
Mon, 21 Jul 2025 14:00:51 +0000 (16:00 +0200)
committer: Robin Dapp <rdapp@ventanamicro.com>
Fri, 25 Jul 2025 13:07:41 +0000 (15:07 +0200)
This patch prepares the dynamic LMUL vector costing to use the coming
SLP_TREE_TYPE instead of the (to-be-removed) STMT_VINFO_TYPE.

Even though the whole approach should be reviewed and adjusted at some
point, the patch chooses the path of least resistance and uses a hash
map for the stmt_info -> slp node relationship.  Each stmt_info is mapped
to its accompanying slp node during add_stmt_cost.  In finish_cost we go
through all statements as before, and obtain the corresponding slp nodes
as well as their types.

This allows us to operate largely as before.  We don't yet switch over
from STMT_VINFO_TYPE to SLP_TREE_TYPE, though; this patch only takes care
of the necessary refactoring upfront.

Regtested on rv64gcv_zvl512b with -mrvv-max-lmul=dynamic.  There are a
few regressions but nothing worse than what we already have.  I'd rather
accept these now and take it as an incentive to work on the heuristic
later than block the SLP work until it is fixed.

gcc/ChangeLog:

* config/riscv/riscv-vector-costs.cc (get_live_range):
Move compute_local_program_points to cost class.
(variable_vectorized_p): Add slp node parameter.
(need_additional_vector_vars_p): Move from here...
(costs::need_additional_vector_vars_p): ... to here and add slp
parameter.
(compute_estimated_lmul): Move update_local_live_ranges to cost
class.
(has_unexpected_spills_p): Move from here...
(costs::has_unexpected_spills_p): ... to here.
(costs::record_lmul_spills): New function.
(costs::add_stmt_cost): Add stmt_info, slp mapping.
(costs::finish_cost): Analyze loop.
* config/riscv/riscv-vector-costs.h: Move declarations to class.

gcc/config/riscv/riscv-vector-costs.cc
gcc/config/riscv/riscv-vector-costs.h

index 4d8170de9b2c86aa4a2e014432305fc676ced359..df924fafd8e562320a51a2c167784b52c0077996 100644 (file)
@@ -178,8 +178,8 @@ get_live_range (hash_map<tree, pair> *live_ranges, tree arg)
        STMT 5 (be vectorized)      -- point 2
        ...
 */
-static void
-compute_local_program_points (
+void
+costs::compute_local_program_points (
   vec_info *vinfo,
   hash_map<basic_block, vec<stmt_point>> &program_points_per_bb)
 {
@@ -274,14 +274,14 @@ loop_invariant_op_p (class loop *loop,
 
 /* Return true if the variable should be counted into liveness.  */
 static bool
-variable_vectorized_p (class loop *loop, stmt_vec_info stmt_info, tree var,
-                      bool lhs_p)
+variable_vectorized_p (class loop *loop, stmt_vec_info stmt_info,
+                      slp_tree node ATTRIBUTE_UNUSED, tree var, bool lhs_p)
 {
   if (!var)
     return false;
   gimple *stmt = STMT_VINFO_STMT (stmt_info);
-  enum stmt_vec_info_type type
-    = STMT_VINFO_TYPE (vect_stmt_to_vectorize (stmt_info));
+  stmt_info = vect_stmt_to_vectorize (stmt_info);
+  enum stmt_vec_info_type type = STMT_VINFO_TYPE (stmt_info);
   if (is_gimple_call (stmt) && gimple_call_internal_p (stmt))
     {
       if (gimple_call_internal_fn (stmt) == IFN_MASK_STORE
@@ -357,8 +357,8 @@ variable_vectorized_p (class loop *loop, stmt_vec_info stmt_info, tree var,
 
    The live range of SSA 1 is [1, 3] in bb 2.
    The live range of SSA 2 is [0, 4] in bb 3.  */
-static machine_mode
-compute_local_live_ranges (
+machine_mode
+costs::compute_local_live_ranges (
   loop_vec_info loop_vinfo,
   const hash_map<basic_block, vec<stmt_point>> &program_points_per_bb,
   hash_map<basic_block, hash_map<tree, pair>> &live_ranges_per_bb)
@@ -388,8 +388,11 @@ compute_local_live_ranges (
              unsigned int point = program_point.point;
              gimple *stmt = program_point.stmt;
              tree lhs = gimple_get_lhs (stmt);
-             if (variable_vectorized_p (loop, program_point.stmt_info, lhs,
-                                        true))
+             slp_tree *node = vinfo_slp_map.get (program_point.stmt_info);
+             if (!node)
+               continue;
+             if (variable_vectorized_p (loop, program_point.stmt_info,
+                                        *node, lhs, true))
                {
                  biggest_mode = get_biggest_mode (biggest_mode,
                                                   TYPE_MODE (TREE_TYPE (lhs)));
@@ -406,8 +409,8 @@ compute_local_live_ranges (
              for (i = 0; i < gimple_num_args (stmt); i++)
                {
                  tree var = gimple_arg (stmt, i);
-                 if (variable_vectorized_p (loop, program_point.stmt_info, var,
-                                            false))
+                 if (variable_vectorized_p (loop, program_point.stmt_info,
+                                            *node, var, false))
                    {
                      biggest_mode
                        = get_biggest_mode (biggest_mode,
@@ -597,11 +600,11 @@ get_store_value (gimple *stmt)
 }
 
 /* Return true if additional vector vars needed.  */
-static bool
-need_additional_vector_vars_p (stmt_vec_info stmt_info)
+bool
+costs::need_additional_vector_vars_p (stmt_vec_info stmt_info,
+                                     slp_tree node ATTRIBUTE_UNUSED)
 {
-  enum stmt_vec_info_type type
-    = STMT_VINFO_TYPE (vect_stmt_to_vectorize (stmt_info));
+  enum stmt_vec_info_type type = STMT_VINFO_TYPE (stmt_info);
   if (type == load_vec_info_type || type == store_vec_info_type)
     {
       if (STMT_VINFO_GATHER_SCATTER_P (stmt_info)
@@ -657,8 +660,8 @@ compute_estimated_lmul (loop_vec_info loop_vinfo, machine_mode mode)
 
    Then, after this function, we update SSA 1 live range in bb 2
    into [2, 4] since SSA 1 is live out into bb 3.  */
-static void
-update_local_live_ranges (
+void
+costs::update_local_live_ranges (
   vec_info *vinfo,
   hash_map<basic_block, vec<stmt_point>> &program_points_per_bb,
   hash_map<basic_block, hash_map<tree, pair>> &live_ranges_per_bb,
@@ -685,8 +688,13 @@ update_local_live_ranges (
        {
          gphi *phi = psi.phi ();
          stmt_vec_info stmt_info = vinfo->lookup_stmt (phi);
-         if (STMT_VINFO_TYPE (vect_stmt_to_vectorize (stmt_info))
-             == undef_vec_info_type)
+         stmt_info = vect_stmt_to_vectorize (stmt_info);
+         slp_tree *node = vinfo_slp_map.get (stmt_info);
+
+         if (!node)
+           continue;
+
+         if (STMT_VINFO_TYPE (stmt_info) == undef_vec_info_type)
            continue;
 
          for (j = 0; j < gimple_phi_num_args (phi); j++)
@@ -761,9 +769,12 @@ update_local_live_ranges (
          if (!is_gimple_assign_or_call (gsi_stmt (si)))
            continue;
          stmt_vec_info stmt_info = vinfo->lookup_stmt (gsi_stmt (si));
-         enum stmt_vec_info_type type
-           = STMT_VINFO_TYPE (vect_stmt_to_vectorize (stmt_info));
-         if (need_additional_vector_vars_p (stmt_info))
+         stmt_info = vect_stmt_to_vectorize (stmt_info);
+         slp_tree *node = vinfo_slp_map.get (stmt_info);
+         if (!node)
+           continue;
+         enum stmt_vec_info_type type = STMT_VINFO_TYPE (stmt_info);
+         if (need_additional_vector_vars_p (stmt_info, *node))
            {
              /* For non-adjacent load/store STMT, we will potentially
                 convert it into:
@@ -816,8 +827,8 @@ update_local_live_ranges (
 }
 
 /* Compute the maximum live V_REGS.  */
-static bool
-has_unexpected_spills_p (loop_vec_info loop_vinfo)
+bool
+costs::has_unexpected_spills_p (loop_vec_info loop_vinfo)
 {
   /* Compute local program points.
      It's a fast and effective computation.  */
@@ -899,7 +910,11 @@ costs::analyze_loop_vinfo (loop_vec_info loop_vinfo)
   /* Detect whether we're vectorizing for VLA and should apply the unrolling
      heuristic described above m_unrolled_vls_niters.  */
   record_potential_vls_unrolling (loop_vinfo);
+}
 
+void
+costs::record_lmul_spills (loop_vec_info loop_vinfo)
+{
   /* Detect whether the LOOP has unexpected spills.  */
   record_potential_unexpected_spills (loop_vinfo);
 }
@@ -1239,8 +1254,12 @@ costs::add_stmt_cost (int count, vect_cost_for_stmt kind,
   int stmt_cost
     = targetm.vectorize.builtin_vectorization_cost (kind, vectype, misalign);
 
+  if (stmt_info && node)
+    vinfo_slp_map.put (stmt_info, node);
+
   /* Do one-time initialization based on the vinfo.  */
   loop_vec_info loop_vinfo = dyn_cast<loop_vec_info> (m_vinfo);
+
   if (!m_analyzed_vinfo)
     {
       if (loop_vinfo)
@@ -1326,6 +1345,8 @@ costs::finish_cost (const vector_costs *scalar_costs)
 {
   if (loop_vec_info loop_vinfo = dyn_cast<loop_vec_info> (m_vinfo))
     {
+      record_lmul_spills (loop_vinfo);
+
       adjust_vect_cost_per_loop (loop_vinfo);
     }
   vector_costs::finish_cost (scalar_costs);
index de546a66f5cae0163226d29f8c6577423ea08da4..b84ceb1d3cf0bdd7b41a99ca85883d4b2423287f 100644 (file)
@@ -91,7 +91,10 @@ private:
   typedef pair_hash <tree_operand_hash, tree_operand_hash> tree_pair_hash;
   hash_set <tree_pair_hash> memrefs;
 
+  hash_map <stmt_vec_info, slp_tree> vinfo_slp_map;
+
   void analyze_loop_vinfo (loop_vec_info);
+  void record_lmul_spills (loop_vec_info loop_vinfo);
   void record_potential_vls_unrolling (loop_vec_info);
   bool prefer_unrolled_loop () const;
 
@@ -103,6 +106,19 @@ private:
   bool m_has_unexpected_spills_p = false;
   void record_potential_unexpected_spills (loop_vec_info);
 
+  void compute_local_program_points (vec_info *,
+                                    hash_map<basic_block, vec<stmt_point>> &);
+  void update_local_live_ranges (vec_info *,
+                                hash_map<basic_block, vec<stmt_point>> &,
+                                hash_map<basic_block, hash_map<tree, pair>> &,
+                                machine_mode *);
+  machine_mode compute_local_live_ranges
+    (loop_vec_info, const hash_map<basic_block, vec<stmt_point>> &,
+     hash_map<basic_block, hash_map<tree, pair>> &);
+
+  bool has_unexpected_spills_p (loop_vec_info);
+  bool need_additional_vector_vars_p (stmt_vec_info, slp_tree);
+
   void adjust_vect_cost_per_loop (loop_vec_info);
   unsigned adjust_stmt_cost (enum vect_cost_for_stmt kind,
                             loop_vec_info,