RISC-V: Support Dynamic LMUL Cost model
/* Cost model implementation for RISC-V 'V' Extension for GNU compiler.
   Copyright (C) 2023-2023 Free Software Foundation, Inc.
   Contributed by Juzhe Zhong (juzhe.zhong@rivai.ai), RiVAI Technologies Ltd.

This file is part of GCC.

GCC is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 3, or (at your option)
any later version.

GCC is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with GCC; see the file COPYING3.  If not see
<http://www.gnu.org/licenses/>.  */

#define IN_TARGET_CODE 1

#define INCLUDE_STRING
#include "config.h"
#include "system.h"
#include "coretypes.h"
#include "tm.h"
#include "target.h"
#include "function.h"
#include "tree.h"
#include "basic-block.h"
#include "rtl.h"
#include "gimple.h"
#include "targhooks.h"
#include "cfgloop.h"
#include "fold-const.h"
#include "tm_p.h"
#include "tree-vectorizer.h"
#include "gimple-iterator.h"
#include "bitmap.h"
#include "ssa.h"
#include "backend.h"

/* This file should be included last.  */
#include "riscv-vector-costs.h"

namespace riscv_vector {

/* Dynamic LMUL philosophy - determine LMUL from a local, linear-scan,
   SSA-live-range-based analysis:

   - Collect all vectorized STMTs locally for each loop block.
   - Build a program-point-based graph, ignoring non-vectorized STMTs:

       vectorized STMT 0 - point 0
       scalar STMT 0     - ignored.
       vectorized STMT 1 - point 1
       ...
   - Compute the number of V_REGs live at each program point.
   - Determine the LMUL in the VECTOR COST model according to the program
     point which has the maximum number of live V_REGs.

   Note:

   - BIGGEST_MODE is the element mode that is auto-vectorized with the
     biggest LMUL.  It's important for mixed-size auto-vectorization
     (conversions, etc.).  E.g. for a loop vectorizing an INT32 -> INT64
     conversion, the biggest mode is DImode; if DImode uses LMUL = 8, then
     SImode uses LMUL = 4.  We compute the number of live V_REGs at each
     program point according to this information.
   - We only compute program points and live ranges locally (within a block)
     since we just need the number of live V_REGs at each program point and
     we are not really allocating a register for each SSA name.  A variable
     that is live out of/into another block simply gets another local live
     range in that block, so this approach doesn't affect the accuracy of
     our live range analysis.
   - The current analysis doesn't consider instruction scheduling, which
     may reduce register pressure, so the analysis is conservative and may
     end up with a smaller LMUL.
     TODO: Maybe we could support a reasonable live range shrinking
     algorithm which takes advantage of instruction scheduling.
   - The analysis may walk through the following sequences of autovec modes:

     1. M8 -> M4 -> M2 -> M1 (stop analysis here) -> MF2 -> MF4 -> MF8
     2. M8 -> M1(M4) -> MF2(M2) -> MF4(M1) (stop analysis here) -> MF8(MF2)
     3. M1(M8) -> MF2(M4) -> MF4(M2) -> MF8(M1)
*/
static hash_map<class loop *, autovec_info> loop_autovec_infos;
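
/* As an illustration of the idea (with made-up numbers): suppose the
   biggest element mode in a loop is DImode and the analysis starts at
   LMUL = 8 (sequence 1 above).  If the busiest program point has 5 DImode
   and 2 SImode values live, the estimated pressure is 5 * 8 + 2 * 4 = 48
   V_REGs, which exceeds the 32 vector registers of the 'V' extension, so
   M8 is rejected; retrying at LMUL = 4 the same point needs
   5 * 4 + 2 * 2 = 24 V_REGs, which fits, so M4 is chosen.  */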

/* Collect all STMTs that are vectorized and compute their program points.
   Note that we don't care about the STMTs that are not vectorized and
   we only build the local graph (within a block) of program points.

   Loop:
     bb 2:
       STMT 1 (vectorized)      -- point 0
       STMT 2 (not vectorized)  -- ignored
       STMT 3 (vectorized)      -- point 1
       STMT 4 (vectorized)      -- point 2
       STMT 5 (vectorized)      -- point 3
       ...
     bb 3:
       STMT 1 (vectorized)      -- point 0
       STMT 2 (vectorized)      -- point 1
       STMT 3 (not vectorized)  -- ignored
       STMT 4 (not vectorized)  -- ignored
       STMT 5 (vectorized)      -- point 2
       ...
*/
static void
compute_local_program_points (
  vec_info *vinfo,
  hash_map<basic_block, vec<stmt_point>> &program_points_per_bb)
{
  if (loop_vec_info loop_vinfo = dyn_cast<loop_vec_info> (vinfo))
    {
      class loop *loop = LOOP_VINFO_LOOP (loop_vinfo);
      basic_block *bbs = LOOP_VINFO_BBS (loop_vinfo);
      unsigned int nbbs = loop->num_nodes;
      gimple_stmt_iterator si;
      unsigned int i;
      /* Collect the stmts that are vectorized and mark their program
         points.  */
      for (i = 0; i < nbbs; i++)
        {
          int point = 0;
          basic_block bb = bbs[i];
          vec<stmt_point> program_points = vNULL;
          if (dump_enabled_p ())
            dump_printf_loc (MSG_NOTE, vect_location,
                             "Compute local program points for bb %d:\n",
                             bb->index);
          for (si = gsi_start_bb (bbs[i]); !gsi_end_p (si); gsi_next (&si))
            {
              if (!(is_gimple_assign (gsi_stmt (si))
                    || is_gimple_call (gsi_stmt (si))))
                continue;
              stmt_vec_info stmt_info = vinfo->lookup_stmt (gsi_stmt (si));
              if (STMT_VINFO_TYPE (vect_stmt_to_vectorize (stmt_info))
                  != undef_vec_info_type)
                {
                  stmt_point info = {point, gsi_stmt (si)};
                  program_points.safe_push (info);
                  point++;
                  if (dump_enabled_p ())
                    dump_printf_loc (MSG_NOTE, vect_location,
                                     "program point %d: %G", info.point,
                                     gsi_stmt (si));
                }
            }
          program_points_per_bb.put (bb, program_points);
        }
    }
}

/* Compute local live ranges of each vectorized variable.
   Note that we only compute local live ranges (within a block) since
   local live range information is accurate enough for us to determine
   the LMUL/vectorization factor of the loop.

   Loop:
     bb 2:
       STMT 1               -- point 0
       STMT 2 (def SSA 1)   -- point 1
       STMT 3 (use SSA 1)   -- point 2
       STMT 4               -- point 3
     bb 3:
       STMT 1               -- point 0
       STMT 2               -- point 1
       STMT 3               -- point 2
       STMT 4 (use SSA 2)   -- point 3

   The live range of SSA 1 is [1, 2] in bb 2.
   The live range of SSA 2 is [0, 3] in bb 3 (there is no local def, so it
   is assumed live from the start of the block).  */
static machine_mode
compute_local_live_ranges (
  const hash_map<basic_block, vec<stmt_point>> &program_points_per_bb,
  hash_map<basic_block, hash_map<tree, pair>> &live_ranges_per_bb)
{
  machine_mode biggest_mode = QImode;
  if (!program_points_per_bb.is_empty ())
    {
      auto_vec<tree> visited_vars;
      unsigned int i;
      for (hash_map<basic_block, vec<stmt_point>>::iterator iter
             = program_points_per_bb.begin ();
           iter != program_points_per_bb.end (); ++iter)
        {
          basic_block bb = (*iter).first;
          vec<stmt_point> program_points = (*iter).second;
          bool existed_p = false;
          hash_map<tree, pair> *live_ranges
            = &live_ranges_per_bb.get_or_insert (bb, &existed_p);
          gcc_assert (!existed_p);
          if (dump_enabled_p ())
            dump_printf_loc (MSG_NOTE, vect_location,
                             "Compute local live ranges for bb %d:\n",
                             bb->index);
          for (const auto program_point : program_points)
            {
              unsigned int point = program_point.point;
              gimple *stmt = program_point.stmt;
              machine_mode mode = biggest_mode;
              tree lhs = gimple_get_lhs (stmt);
              if (lhs != NULL_TREE && is_gimple_reg (lhs)
                  && !POINTER_TYPE_P (TREE_TYPE (lhs)))
                {
                  mode = TYPE_MODE (TREE_TYPE (lhs));
                  bool existed_p = false;
                  pair &live_range
                    = live_ranges->get_or_insert (lhs, &existed_p);
                  gcc_assert (!existed_p);
                  live_range = pair (point, point);
                }
              for (i = 0; i < gimple_num_args (stmt); i++)
                {
                  tree var = gimple_arg (stmt, i);
                  /* Both IMM and REG are included since a VECTOR_CST may
                     potentially be held in a vector register.  However, this
                     is not accurate, since a PLUS_EXPR can be vectorized
                     into vadd.vi if the IMM is in the range -16 ~ 15.

                     TODO: We may elide the unnecessary IMM cases in the
                     future.  */
                  if (is_gimple_val (var) && !POINTER_TYPE_P (TREE_TYPE (var)))
                    {
                      mode = TYPE_MODE (TREE_TYPE (var));
                      bool existed_p = false;
                      pair &live_range
                        = live_ranges->get_or_insert (var, &existed_p);
                      if (existed_p)
                        /* We will grow the live range for each use.  */
                        live_range = pair (live_range.first, point);
                      else
                        /* We assume the variable is live from the start of
                           this block.  */
                        live_range = pair (0, point);
                    }
                }
              if (GET_MODE_SIZE (mode).to_constant ()
                  > GET_MODE_SIZE (biggest_mode).to_constant ())
                biggest_mode = mode;
            }
          if (dump_enabled_p ())
            for (hash_map<tree, pair>::iterator iter = live_ranges->begin ();
                 iter != live_ranges->end (); ++iter)
              dump_printf_loc (MSG_NOTE, vect_location,
                               "%T: type = %T, start = %d, end = %d\n",
                               (*iter).first, TREE_TYPE ((*iter).first),
                               (*iter).second.first, (*iter).second.second);
        }
    }
  if (dump_enabled_p ())
    dump_printf_loc (MSG_NOTE, vect_location, "Biggest mode = %s\n",
                     GET_MODE_NAME (biggest_mode));
  return biggest_mode;
}
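
/* For instance, in a loop that loads int16_t values, widens them and
   stores int32_t results, the statements above would record HImode and
   SImode defs/uses, so the returned BIGGEST_MODE is SImode and the HImode
   values are later costed as occupying half as many V_REGs (see
   compute_nregs_for_mode below).  */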

/* Compute the number of V_REGs occupied by a value of mode MODE, given
   BIGGEST_MODE and the LMUL used for BIGGEST_MODE.

   E.g. if mode = SImode, biggest_mode = DImode and LMUL = 4, then DImode
   values are vectorized as RVVM4DImode (4 V_REGs) while SImode values are
   vectorized as RVVM2SImode, so return 2 V_REGs.  */
static unsigned int
compute_nregs_for_mode (machine_mode mode, machine_mode biggest_mode, int lmul)
{
  unsigned int mode_size = GET_MODE_SIZE (mode).to_constant ();
  unsigned int biggest_size = GET_MODE_SIZE (biggest_mode).to_constant ();
  gcc_assert (biggest_size >= mode_size);
  unsigned int ratio = biggest_size / mode_size;
  return lmul / ratio;
}
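
/* A hypothetical worked example of the arithmetic above: with
   biggest_mode = DImode and lmul = 8 (M8), a QImode value has
   ratio = 8 / 1 = 8 and occupies 8 / 8 = 1 V_REG, an SImode value has
   ratio = 2 and occupies 4 V_REGs, and a DImode value occupies all
   8 V_REGs.  The register pressure at a program point is the sum of these
   counts over every variable whose live range covers that point.  */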

/* This function helps to determine whether the current LMUL will cause
   potential vector register (V_REG) spilling according to live range
   information.

     - First, compute how many variables are live at each program point
       in each bb of the loop.
     - Second, compute how many V_REGs are live at each program point
       in each bb of the loop according to BIGGEST_MODE and the variable
       modes.
     - Third, return the maximum number of live V_REGs over the loop.  */
static unsigned int
max_number_of_live_regs (const basic_block bb,
                         const hash_map<tree, pair> &live_ranges,
                         unsigned int max_point, machine_mode biggest_mode,
                         int lmul)
{
  unsigned int max_nregs = 0;
  unsigned int i;
  unsigned int live_point = 0;
  auto_vec<unsigned int> live_vars_vec;
  live_vars_vec.safe_grow (max_point + 1, true);
  for (i = 0; i < live_vars_vec.length (); ++i)
    live_vars_vec[i] = 0;
  for (hash_map<tree, pair>::iterator iter = live_ranges.begin ();
       iter != live_ranges.end (); ++iter)
    {
      tree var = (*iter).first;
      pair live_range = (*iter).second;
      for (i = live_range.first; i <= live_range.second; i++)
        {
          machine_mode mode = TYPE_MODE (TREE_TYPE (var));
          unsigned int nregs
            = compute_nregs_for_mode (mode, biggest_mode, lmul);
          live_vars_vec[i] += nregs;
          if (live_vars_vec[i] > max_nregs)
            max_nregs = live_vars_vec[i];
        }
    }

  /* Collect explicitly used RVV types.  */
  auto_vec<basic_block> all_preds
    = get_all_dominated_blocks (CDI_POST_DOMINATORS, bb);
  tree t;
  FOR_EACH_SSA_NAME (i, t, cfun)
    {
      machine_mode mode = TYPE_MODE (TREE_TYPE (t));
      if (!lookup_vector_type_attribute (TREE_TYPE (t))
          && !riscv_v_ext_vls_mode_p (mode))
        continue;

      gimple *def = SSA_NAME_DEF_STMT (t);
      if (gimple_bb (def) && !all_preds.contains (gimple_bb (def)))
        continue;
      use_operand_p use_p;
      imm_use_iterator iterator;

      FOR_EACH_IMM_USE_FAST (use_p, iterator, t)
        {
          if (!USE_STMT (use_p) || is_gimple_debug (USE_STMT (use_p))
              || !dominated_by_p (CDI_POST_DOMINATORS, bb,
                                  gimple_bb (USE_STMT (use_p))))
            continue;

          int regno_alignment = riscv_get_v_regno_alignment (mode);
          max_nregs += regno_alignment;
          if (dump_enabled_p ())
            dump_printf_loc (
              MSG_NOTE, vect_location,
              "Explicit used SSA %T, vectype = %T, mode = %s, cause %d "
              "V_REG live in bb %d at program point %d\n",
              t, TREE_TYPE (t), GET_MODE_NAME (mode), regno_alignment,
              bb->index, live_point);
          break;
        }
    }

  if (dump_enabled_p ())
    dump_printf_loc (MSG_NOTE, vect_location,
                     "Maximum lmul = %d, %d number of live V_REG at program "
                     "point %d for bb %d\n",
                     lmul, max_nregs, live_point, bb->index);
  return max_nregs;
}
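
/* A small made-up example of the accumulation above: with max_point = 3,
   live ranges SSA 1 = [0, 2], SSA 2 = [1, 3] and SSA 3 = [2, 3], all of
   BIGGEST_MODE, and lmul = 4, the per-point counts are {4, 8, 12, 8}, so
   the function returns 12; the same ranges at lmul = 8 would peak at 24,
   plus whatever the explicitly used RVV types contribute.  */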

/* Return the LMUL of the current analysis.  */
static int
get_current_lmul (class loop *loop)
{
  return loop_autovec_infos.get (loop)->current_lmul;
}

/* Update the live ranges according to PHIs.

   Loop:
     bb 2:
       STMT 1               -- point 0
       STMT 2 (def SSA 1)   -- point 1
       STMT 3 (use SSA 1)   -- point 2
       STMT 4               -- point 3
     bb 3:
       SSA 2 = PHI<SSA 1>
       STMT 1               -- point 0
       STMT 2               -- point 1
       STMT 3 (use SSA 2)   -- point 2
       STMT 4               -- point 3

   Before this function, the live range of SSA 1 is [1, 2] in bb 2
   and that of SSA 2 is [0, 2] in bb 3.

   After this function, the live range of SSA 1 in bb 2 is extended
   to [1, 3], the last program point of bb 2, since SSA 1 is live out
   into bb 3 through the PHI.  */
static void
update_local_live_ranges (
  vec_info *vinfo,
  hash_map<basic_block, vec<stmt_point>> &program_points_per_bb,
  hash_map<basic_block, hash_map<tree, pair>> &live_ranges_per_bb)
{
  loop_vec_info loop_vinfo = dyn_cast<loop_vec_info> (vinfo);
  if (!loop_vinfo)
    return;

  class loop *loop = LOOP_VINFO_LOOP (loop_vinfo);
  basic_block *bbs = LOOP_VINFO_BBS (loop_vinfo);
  unsigned int nbbs = loop->num_nodes;
  unsigned int i, j;
  gphi_iterator psi;
  for (i = 0; i < nbbs; i++)
    {
      basic_block bb = bbs[i];
      if (dump_enabled_p ())
        dump_printf_loc (MSG_NOTE, vect_location,
                         "Update local program points for bb %d:\n",
                         bb->index);
      for (psi = gsi_start_phis (bbs[i]); !gsi_end_p (psi); gsi_next (&psi))
        {
          gphi *phi = psi.phi ();
          stmt_vec_info stmt_info = vinfo->lookup_stmt (phi);
          if (STMT_VINFO_TYPE (vect_stmt_to_vectorize (stmt_info))
              == undef_vec_info_type)
            continue;

          for (j = 0; j < gimple_phi_num_args (phi); j++)
            {
              edge e = gimple_phi_arg_edge (phi, j);
              tree def = gimple_phi_arg_def (phi, j);
              auto *live_ranges = live_ranges_per_bb.get (e->src);
              if (!program_points_per_bb.get (e->src))
                continue;
              unsigned int max_point
                = (*program_points_per_bb.get (e->src)).length () - 1;
              auto *live_range = live_ranges->get (def);
              if (!live_range)
                continue;

              unsigned int end = (*live_range).second;
              (*live_range).second = max_point;
              if (dump_enabled_p ())
                dump_printf_loc (MSG_NOTE, vect_location,
                                 "Update %T end point from %d to %d:\n", def,
                                 end, (*live_range).second);
            }
        }
    }
}
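
/* The typical case where this extension matters is a loop-carried value
   such as an induction or reduction variable: its definition feeds a PHI
   over the back edge, so it effectively stays live until the end of its
   defining block and overlaps every vectorized statement after its
   definition.  */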

costs::costs (vec_info *vinfo, bool costing_for_scalar)
  : vector_costs (vinfo, costing_for_scalar)
{}

/* Return true if the LMUL of the new COST model is preferred.  */
bool
costs::preferred_new_lmul_p (const vector_costs *uncast_other) const
{
  auto other = static_cast<const costs *> (uncast_other);
  auto this_loop_vinfo = as_a<loop_vec_info> (this->m_vinfo);
  auto other_loop_vinfo = as_a<loop_vec_info> (other->m_vinfo);
  class loop *loop = LOOP_VINFO_LOOP (this_loop_vinfo);

  if (!LOOP_VINFO_CAN_USE_PARTIAL_VECTORS_P (this_loop_vinfo)
      && LOOP_VINFO_CAN_USE_PARTIAL_VECTORS_P (other_loop_vinfo))
    return false;

  if (loop_autovec_infos.get (loop) && loop_autovec_infos.get (loop)->end_p)
    return false;
  else if (loop_autovec_infos.get (loop))
    loop_autovec_infos.get (loop)->current_lmul
      = loop_autovec_infos.get (loop)->current_lmul / 2;
  else
    {
      int regno_alignment
        = riscv_get_v_regno_alignment (other_loop_vinfo->vector_mode);
      if (known_eq (LOOP_VINFO_SLP_UNROLLING_FACTOR (other_loop_vinfo), 1U))
        regno_alignment = RVV_M8;
      loop_autovec_infos.put (loop, {regno_alignment, regno_alignment, false});
    }

  int lmul = get_current_lmul (loop);
  if (dump_enabled_p ())
    dump_printf_loc (MSG_NOTE, vect_location,
                     "Comparing two main loops (%s at VF %d vs %s at VF %d)\n",
                     GET_MODE_NAME (this_loop_vinfo->vector_mode),
                     vect_vf_for_cost (this_loop_vinfo),
                     GET_MODE_NAME (other_loop_vinfo->vector_mode),
                     vect_vf_for_cost (other_loop_vinfo));

  /* Compute local program points.
     It's a fast and effective computation.  */
  hash_map<basic_block, vec<stmt_point>> program_points_per_bb;
  compute_local_program_points (other->m_vinfo, program_points_per_bb);

  /* Compute local live ranges.  */
  hash_map<basic_block, hash_map<tree, pair>> live_ranges_per_bb;
  machine_mode biggest_mode
    = compute_local_live_ranges (program_points_per_bb, live_ranges_per_bb);

  /* Update live ranges according to PHIs.  */
  update_local_live_ranges (other->m_vinfo, program_points_per_bb,
                            live_ranges_per_bb);

  /* TODO: We calculate the maximum live vars based on the current STMT
     sequence.  We can support live range shrinking in the future if it
     gives a big improvement.  */
  if (!live_ranges_per_bb.is_empty ())
    {
      unsigned int max_nregs = 0;
      for (hash_map<basic_block, hash_map<tree, pair>>::iterator iter
             = live_ranges_per_bb.begin ();
           iter != live_ranges_per_bb.end (); ++iter)
        {
          basic_block bb = (*iter).first;
          unsigned int max_point
            = (*program_points_per_bb.get (bb)).length () - 1;
          if ((*iter).second.is_empty ())
            continue;
          /* We prefer a larger LMUL unless it causes register spilling.  */
          unsigned int nregs
            = max_number_of_live_regs (bb, (*iter).second, max_point,
                                       biggest_mode, lmul);
          if (nregs > max_nregs)
            max_nregs = nregs;
        }
      live_ranges_per_bb.empty ();
      if (loop_autovec_infos.get (loop)->current_lmul == RVV_M1
          || max_nregs <= V_REG_NUM)
        loop_autovec_infos.get (loop)->end_p = true;
      if (loop_autovec_infos.get (loop)->current_lmul > RVV_M1)
        return max_nregs > V_REG_NUM;
      return false;
    }
  if (!program_points_per_bb.is_empty ())
    {
      for (hash_map<basic_block, vec<stmt_point>>::iterator iter
             = program_points_per_bb.begin ();
           iter != program_points_per_bb.end (); ++iter)
        {
          vec<stmt_point> program_points = (*iter).second;
          if (!program_points.is_empty ())
            program_points.release ();
        }
      program_points_per_bb.empty ();
    }
  return lmul > RVV_M1;
}
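
/* A rough sketch of how the pieces above fit together: the vectorizer
   costs the same loop once per candidate vector mode and compares the
   candidates pairwise via better_main_loop_than_p.  When riscv_autovec_lmul
   is RVV_DYNAMIC, the first comparison seeds loop_autovec_infos (normally
   with M8), each later comparison halves current_lmul, and the descent
   stops (end_p) once the estimated pressure fits in V_REG_NUM vector
   registers or LMUL reaches M1.  */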

bool
costs::better_main_loop_than_p (const vector_costs *uncast_other) const
{
  auto other = static_cast<const costs *> (uncast_other);

  if (!flag_vect_cost_model)
    return vector_costs::better_main_loop_than_p (other);

  if (riscv_autovec_lmul == RVV_DYNAMIC)
    {
      bool post_dom_available_p = dom_info_available_p (CDI_POST_DOMINATORS);
      if (!post_dom_available_p)
        calculate_dominance_info (CDI_POST_DOMINATORS);
      bool preferred_p = preferred_new_lmul_p (uncast_other);
      if (!post_dom_available_p)
        free_dominance_info (CDI_POST_DOMINATORS);
      return preferred_p;
    }

  return vector_costs::better_main_loop_than_p (other);
}

unsigned
costs::add_stmt_cost (int count, vect_cost_for_stmt kind,
                      stmt_vec_info stmt_info, slp_tree, tree vectype,
                      int misalign, vect_cost_model_location where)
{
  /* TODO: Use the default STMT cost model for now.
     We will support a more accurate STMT cost model later.  */
  int stmt_cost = default_builtin_vectorization_cost (kind, vectype, misalign);
  return record_stmt_cost (stmt_info, where, count * stmt_cost);
}

void
costs::finish_cost (const vector_costs *scalar_costs)
{
  vector_costs::finish_cost (scalar_costs);
}

} // namespace riscv_vector