return res;
}
+/* Given that the target fully pipelines FMA instructions, return the latency
+ of MULT_EXPRs that can't be hidden by the FMAs. WIDTH is the number of
+ pipes. */
+
+static inline int
+get_mult_latency_consider_fma (int ops_num, int mult_num, int width)
+{
+ gcc_checking_assert (mult_num && mult_num <= ops_num);
+
+ /* If the longest per-pipe chain consists entirely of MULT_EXPRs, i.e.
+ CEIL (ops_num, width) == CEIL (mult_num, width), there's latency(MULT)*2
+ on the critical path.  e.g.:
+
+ A * B + C * D
+ =>
+ _1 = A * B;
+ _2 = .FMA (C, D, _1);
+
+ Otherwise there's latency(MULT)*1 in the first FMA, e.g.:
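+
+ E + A * B + C * D
+ =>
+ _1 = .FMA (A, B, E);
+ _2 = .FMA (C, D, _1);
+
+ where only the first FMA's multiply is exposed; the second FMA's
+ multiply overlaps with the first FMA.  */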
+ return CEIL (ops_num, width) == CEIL (mult_num, width) ? 2 : 1;
+}
+
/* Returns an optimal number of registers to use for computation of
given statements.
- LHS is the result ssa name of OPS. */
+ LHS is the result ssa name of OPS.  MULT_NUM is the number of
+ sub-expressions that are MULT_EXPRs, when OPS are PLUS_EXPRs or
+ MINUS_EXPRs. */
static int
-get_reassociation_width (vec<operand_entry *> *ops, tree lhs,
+get_reassociation_width (vec<operand_entry *> *ops, int mult_num, tree lhs,
enum tree_code opc, machine_mode mode)
{
int param_width = param_tree_reassoc_width;
so we can perform a binary search for the minimal width that still
results in the optimal cycle count. */
width_min = 1;
- while (width > width_min)
+
+ /* If the target fully pipelines FMA instructions, the multiply part can
+ start as soon as its operands are ready.  Assuming symmetric pipes are
+ used for FMUL/FADD/FMA, then for a sequence of FMAs like:
+
+ _8 = .FMA (_2, _3, _1);
+ _9 = .FMA (_5, _4, _8);
+ _10 = .FMA (_7, _6, _9);
+
+ with width=1 the latency is latency(MULT) + latency(ADD)*3, while
+ with width=2:
+
+ _8 = _4 * _5;
+ _9 = .FMA (_2, _3, _1);
+ _10 = .FMA (_6, _7, _8);
+ _11 = _9 + _10;
+
+ it is latency(MULT)*2 + latency(ADD)*2.  Assuming latency(MULT) >=
+ latency(ADD), the first variant is preferred.
+
+ Find out if we can get a smaller width considering FMA. */
+ if (width > 1 && mult_num && param_fully_pipelined_fma)
{
- int width_mid = (width + width_min) / 2;
+ /* When param_fully_pipelined_fma is set, assume FMUL and FMA use the
+ same units that can also do FADD.  For other scenarios, such as when
+ FMUL and FADD use separate units, the following code may not
+ apply.  */
+ int width_mult = targetm.sched.reassociation_width (MULT_EXPR, mode);
+ gcc_checking_assert (width_mult <= width);
+
+ /* Latency of MULT_EXPRs. */
+ int lat_mul
+ = get_mult_latency_consider_fma (ops_num, mult_num, width_mult);
+
+ /* The quicker binary search below might not apply here, so check
+ each width linearly, starting from 1.  */
+ for (int i = 1; i < width_mult; i++)
+ {
+ int lat_mul_new
+ = get_mult_latency_consider_fma (ops_num, mult_num, i);
+ int lat_add_new = get_required_cycles (ops_num, i);
- if (get_required_cycles (ops_num, width_mid) == cycles_best)
- width = width_mid;
- else if (width_min < width_mid)
- width_min = width_mid;
- else
- break;
+ /* Assume latency(MULT) >= latency(ADD). */
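+ /* Narrowing the width to I saves (lat_mul - lat_mul_new) units of
+ MULT latency at the cost of (lat_add_new - cycles_best) extra ADD
+ cycles; under that assumption, the narrower width is not slower
+ whenever the former covers the latter.  */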
+ if (lat_mul - lat_mul_new >= lat_add_new - cycles_best)
+ {
+ width = i;
+ break;
+ }
+ }
+ }
+ else
+ {
+ while (width > width_min)
+ {
+ int width_mid = (width + width_min) / 2;
+
+ if (get_required_cycles (ops_num, width_mid) == cycles_best)
+ width = width_mid;
+ else if (width_min < width_mid)
+ width_min = width_mid;
+ else
+ break;
+ }
}
/* If there's loop dependent FMA result, return width=2 to avoid it. This is
Rearrange ops to -> e + a * b + c * d generates:
_4 = .FMA (c_7(D), d_8(D), _3);
- _11 = .FMA (a_5(D), b_6(D), _4); */
-static bool
+ _11 = .FMA (a_5(D), b_6(D), _4);
+
+ Return the number of MULT_EXPRs in the chain. */
+static int
rank_ops_for_fma (vec<operand_entry *> *ops)
{
operand_entry *oe;
if (TREE_CODE (oe->op) == SSA_NAME)
{
gimple *def_stmt = SSA_NAME_DEF_STMT (oe->op);
- if (is_gimple_assign (def_stmt)
- && gimple_assign_rhs_code (def_stmt) == MULT_EXPR)
- ops_mult.safe_push (oe);
+ if (is_gimple_assign (def_stmt))
+ {
+ if (gimple_assign_rhs_code (def_stmt) == MULT_EXPR)
+ ops_mult.safe_push (oe);
+ /* A negate on the multiplication leads to FNMA. */
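+ /* E.g. _1 = A * B; _2 = -_1; then _2 + C can become
+ _3 = .FNMA (A, B, C).  */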
+ else if (gimple_assign_rhs_code (def_stmt) == NEGATE_EXPR
+ && TREE_CODE (gimple_assign_rhs1 (def_stmt)) == SSA_NAME)
+ {
+ gimple *neg_def_stmt
+ = SSA_NAME_DEF_STMT (gimple_assign_rhs1 (def_stmt));
+ if (is_gimple_assign (neg_def_stmt)
+ && gimple_bb (neg_def_stmt) == gimple_bb (def_stmt)
+ && gimple_assign_rhs_code (neg_def_stmt) == MULT_EXPR)
+ ops_mult.safe_push (oe);
+ else
+ ops_others.safe_push (oe);
+ }
+ else
+ ops_others.safe_push (oe);
+ }
else
ops_others.safe_push (oe);
}
Putting ops that are not defined by a mult in front can generate more FMAs.
2. If all ops are defined with mult, we don't need to rearrange them. */
- if (ops_mult.length () >= 2 && ops_mult.length () != ops_length)
+ unsigned mult_num = ops_mult.length ();
+ if (mult_num >= 2 && mult_num != ops_length)
{
/* Put no-mult ops and mult ops alternately at the end of the
queue, which is conducive to generating more FMA and reducing the
if (opindex > 0)
opindex--;
}
- return true;
}
- return false;
+ return mult_num;
}
/* Reassociate expressions in basic block BB and its post-dominator as
children.
{
machine_mode mode = TYPE_MODE (TREE_TYPE (lhs));
int ops_num = ops.length ();
- int width;
- bool has_fma = false;
+ int width = 0;
+ int mult_num = 0;
/* For binary bit operations, if there are at least 3
operands and the last operand in OPS is a constant,
opt_type)
&& (rhs_code == PLUS_EXPR || rhs_code == MINUS_EXPR))
{
- has_fma = rank_ops_for_fma (&ops);
+ mult_num = rank_ops_for_fma (&ops);
}
/* Only rewrite the expression tree to parallel in the
last reassoc pass to avoid useless work back-and-forth
with initial linearization. */
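+ /* rank_ops_for_fma rearranges OPS only when there are at least two
+ MULT_EXPRs and at least one other op; use the same condition to
+ decide whether FMAs will be generated.  */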
+ bool has_fma = mult_num >= 2 && mult_num != ops_num;
if (!reassoc_insert_powi_p
&& ops.length () > 3
- && (width
- = get_reassociation_width (&ops, lhs, rhs_code, mode))
+ && (width = get_reassociation_width (&ops, mult_num, lhs,
+ rhs_code, mode))
> 1)
{
if (dump_file && (dump_flags & TDF_DETAILS))
if (len >= 3
&& (!has_fma
/* width > 1 means ranking ops results in better
- parallelism. */
- || get_reassociation_width (&ops, lhs, rhs_code,
- mode)
- > 1))
+ parallelism. Check current value to avoid
+ calling get_reassociation_width again. */
+ || (width != 1
+ && get_reassociation_width (
+ &ops, mult_num, lhs, rhs_code, mode)
+ > 1)))
swap_ops_for_binary_stmt (ops, len - 3);
new_lhs = rewrite_expr_tree (stmt, rhs_code, 0, ops,