From: Carl Love Date: Fri, 26 Feb 2021 21:46:55 +0000 (-0600) Subject: PPC64: Reduced-Precision - bfloat16 Outer Product & Format Conversion Operations X-Git-Tag: VALGRIND_3_18_0~145 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=078f89e99b6f62e043f6138c6a7ae238befc1f2a;p=thirdparty%2Fvalgrind.git PPC64: Reduced-Precision - bfloat16 Outer Product & Format Conversion Operations Add support for: pmxvbf16ger2 Prefixed Masked VSX Vector bfloat16 GER (Rank-2 Update) pmxvbf16ger2pp Prefixed Masked VSX Vector bfloat16 GER (Rank-2 Update) Positive multiply, Positive accumulate pmxvbf16ger2pn Prefixed Masked VSX Vector bfloat16 GER (Rank-2 Update) Positive multiply, Negative accumulate pmxvbf16ger2np Prefixed Masked VSX Vector bfloat16 GER (Rank-2 Update) Negative multiply, Positive accumulate pmxvbf16ger2nn Prefixed Masked VSX Vector bfloat16 GER (Rank-2 Update) Negative multiply, Negative accumulate xvbf16ger2 VSX Vector bfloat16 GER (Rank-2 Update) xvbf16ger2pp VSX Vector bfloat16 GER (Rank-2 Update) Positive multiply, Positive accumulate xvbf16ger2pn VSX Vector bfloat16 GER (Rank-2 Update) Positive multiply, Negative accumulate xvbf16ger2np VSX Vector bfloat16 GER (Rank-2 Update) Negative multiply, Positive accumulate xvbf16ger2nn VSX Vector bfloat16 GER (Rank-2 Update) Negative multiply, Negative accumulate xvcvbf16sp VSX Vector Convert bfloat16 to Single-Precision format xvcvspbf16 VSX Vector Convert with round Single-Precision to bfloat16 format --- diff --git a/VEX/priv/guest_ppc_defs.h b/VEX/priv/guest_ppc_defs.h index 54ce923a9b..d36d6c07d4 100644 --- a/VEX/priv/guest_ppc_defs.h +++ b/VEX/priv/guest_ppc_defs.h @@ -150,6 +150,8 @@ extern ULong convert_to_zoned_helper( ULong src_hi, ULong src_low, ULong return_upper ); extern ULong convert_to_national_helper( ULong src, ULong return_upper ); extern ULong convert_from_zoned_helper( ULong src_hi, ULong src_low ); +extern ULong convert_from_floattobf16_helper( ULong src ); +extern ULong 
convert_from_bf16tofloat_helper( ULong src ); extern ULong convert_from_national_helper( ULong src_hi, ULong src_low ); extern ULong generate_C_FPCC_helper( ULong size, ULong src_hi, ULong src ); extern ULong extract_bits_under_mask_helper( ULong src, ULong mask, @@ -201,6 +203,11 @@ extern void vector_gen_pvc_dword_mask_dirty_helper( VexGuestPPC64State* gst, #define XVF16GER2PN 0b10010010 #define XVF16GER2NP 0b01010010 #define XVF16GER2NN 0b11010010 +#define XVBF16GER2 0b00110011 +#define XVBF16GER2PP 0b00110010 +#define XVBF16GER2PN 0b10110010 +#define XVBF16GER2NP 0b01110010 +#define XVBF16GER2NN 0b11110010 #define XVF32GER 0b00011011 #define XVF32GERPP 0b00011010 #define XVF32GERPN 0b10011010 diff --git a/VEX/priv/guest_ppc_helpers.c b/VEX/priv/guest_ppc_helpers.c index 75497abb96..6bcee966d2 100644 --- a/VEX/priv/guest_ppc_helpers.c +++ b/VEX/priv/guest_ppc_helpers.c @@ -1905,6 +1905,125 @@ static Double conv_f16_to_double( ULong input ) # endif } +#define BF16_SIGN_MASK 0x8000 +#define BF16_EXP_MASK 0x7F80 +#define BF16_FRAC_MASK 0x007F +#define BF16_BIAS 127 +#define BF16_MAX_UNBIASED_EXP 127 +#define BF16_MIN_UNBIASED_EXP -126 +#define FLOAT_SIGN_MASK 0x80000000 +#define FLOAT_EXP_MASK 0x7F800000 +#define FLOAT_FRAC_MASK 0x007FFFFF +#define FLOAT_FRAC_BIT8 0x00008000 +#define FLOAT_BIAS 127 + +static Float conv_bf16_to_float( UInt input ) +{ + /* input is 16-bit bfloat. + bias +127, exponent 8-bits, fraction 7-bits + + output is 32-bit float. 
+ bias +127, exponent 8-bits, fraction 23-bits + */ + + UInt input_exp, input_fraction, unbiased_exp; + UInt output_exp, output_fraction; + UInt sign; + union convert_t conv; + + sign = (UInt)(input & BF16_SIGN_MASK); + input_exp = input & BF16_EXP_MASK; + unbiased_exp = (input_exp >> 7) - (UInt)BF16_BIAS; + input_fraction = input & BF16_FRAC_MASK; + + if (((input_exp & BF16_EXP_MASK) == BF16_EXP_MASK) && + (input_fraction != 0)) { + /* input is NaN or SNaN, exp all 1's, fraction != 0 */ + output_exp = FLOAT_EXP_MASK; + output_fraction = input_fraction; + + } else if(((input_exp & BF16_EXP_MASK) == BF16_EXP_MASK) && + ( input_fraction == 0)) { + /* input is infinity, exp all 1's, fraction = 0 */ + output_exp = FLOAT_EXP_MASK; + output_fraction = 0; + + } else if((input_exp == 0) && (input_fraction == 0)) { + /* input is zero */ + output_exp = 0; + output_fraction = 0; + + } else if((input_exp == 0) && (input_fraction != 0)) { + /* input is denormal */ + output_fraction = input_fraction; + output_exp = (-(Int)BF16_BIAS + (Int)FLOAT_BIAS ) << 23; + + } else { + /* result is normal */ + output_exp = (unbiased_exp + FLOAT_BIAS) << 23; + output_fraction = input_fraction; + } + + conv.u32 = sign << (31 - 15) | output_exp | (output_fraction << (23-7)); + return conv.f; +} + +static UInt conv_float_to_bf16( UInt input ) +{ + /* input is 32-bit float stored as unsigned 32-bit. + bias +127, exponent 8-bits, fraction 23-bits + + output is 16-bit bfloat. + bias +127, exponent 8-bits, fraction 7-bits + + If the unbiased exponent of the input is greater than the max floating + point unbiased exponent value, the result of the floating point 16-bit + value is infinity. 
+ */ + + UInt input_exp, input_fraction; + UInt output_exp, output_fraction; + UInt result, sign; + + sign = input & FLOAT_SIGN_MASK; + input_exp = input & FLOAT_EXP_MASK; + input_fraction = input & FLOAT_FRAC_MASK; + + if (((input_exp & FLOAT_EXP_MASK) == FLOAT_EXP_MASK) && + (input_fraction != 0)) { + /* input is NaN or SNaN, exp all 1's, fraction != 0 */ + output_exp = BF16_EXP_MASK; + output_fraction = (ULong)input_fraction >> (23 - 7); + } else if (((input_exp & FLOAT_EXP_MASK) == FLOAT_EXP_MASK) && + ( input_fraction == 0)) { + /* input is infinity, exp all 1's, fraction = 0 */ + output_exp = BF16_EXP_MASK; + output_fraction = 0; + } else if ((input_exp == 0) && (input_fraction == 0)) { + /* input is zero */ + output_exp = 0; + output_fraction = 0; + } else if ((input_exp == 0) && (input_fraction != 0)) { + /* input is denormal */ + output_exp = 0; + output_fraction = (ULong)input_fraction >> (23 - 7); + } else { + /* result is normal */ + output_exp = (input_exp - BF16_BIAS + FLOAT_BIAS) >> (23 - 7); + output_fraction = (ULong)input_fraction >> (23 - 7); + + /* Round result. Look at the 8th bit position of the 32-bit floating + point fraction. The bfloat16 fraction is only 7 bits wide so if the 8th + bit of the F32 is a 1 we need to round up by adding 1 to the output + fraction. */ + if ((input_fraction & FLOAT_FRAC_BIT8) == FLOAT_FRAC_BIT8) + /* Round the bfloat16 fraction up by 1 */ + output_fraction = output_fraction + 1; + } + + result = sign >> (31 - 15) | output_exp | output_fraction; + return result; +} static Float conv_double_to_float( Double src ) { @@ -1942,6 +2061,36 @@ static Float negate_float( Float input ) return -input; } +/* This C-helper takes a vector of two 32-bit floating point values + * and returns a vector containing two 16-bit bfloats. + input: word0 word1 + output: 0x0 hword1 0x0 hword3 + Called from generated code. 
+ */ +ULong convert_from_floattobf16_helper( ULong src ) { + ULong resultHi, resultLo; + + resultHi = (ULong)conv_float_to_bf16( (UInt)(src >> 32)); + resultLo = (ULong)conv_float_to_bf16( (UInt)(src & 0xFFFFFFFF)); + return (resultHi << 32) | resultLo; + +} + +/* This C-helper takes a vector of two 16-bit bfloat values + * and returns a vector containing two 32-bit floats. + input: 0x0 hword1 0x0 hword3 + output: word0 word1 + */ +ULong convert_from_bf16tofloat_helper( ULong src ) { + ULong result; + union convert_t conv; + conv.f = conv_bf16_to_float( (UInt)(src >> 32) ); + result = (ULong) conv.u32; + conv.f = conv_bf16_to_float( (UInt)(src & 0xFFFFFFFF)); + result = (result << 32) | (ULong) conv.u32; + return result; + } + void vsx_matrix_16bit_float_ger_dirty_helper( VexGuestPPC64State* gst, UInt offset_ACC, ULong srcA_hi, ULong srcA_lo, @@ -2002,24 +2151,44 @@ void vsx_matrix_16bit_float_ger_dirty_helper( VexGuestPPC64State* gst, srcB_word[0][j] = (UInt)((srcB_lo >> (16-16*j)) & mask); } + /* Note the ISA is not consistent in the src naming. Will use the + naming src10, src11, src20, src21 used with xvf16ger2 instructions. + */ for( j = 0; j < 4; j++) { if (((pmsk >> 1) & 0x1) == 0) { src10 = 0; src20 = 0; } else { - src10 = conv_f16_to_double((ULong)srcA_word[i][0]); - src20 = conv_f16_to_double((ULong)srcB_word[j][0]); + if (( inst == XVF16GER2 ) || ( inst == XVF16GER2PP ) + || ( inst == XVF16GER2PN ) || ( inst == XVF16GER2NP ) + || ( inst == XVF16GER2NN )) { + src10 = conv_f16_to_double((ULong)srcA_word[i][0]); + src20 = conv_f16_to_double((ULong)srcB_word[j][0]); + } else { + /* Input is in bfloat format, result is stored in the + "traditional" 64-bit float format. 
*/ + src10 = (double)conv_bf16_to_float((ULong)srcA_word[i][0]); + src20 = (double)conv_bf16_to_float((ULong)srcB_word[j][0]); + } } if ((pmsk & 0x1) == 0) { src11 = 0; src21 = 0; } else { - src11 = conv_f16_to_double((ULong)srcA_word[i][1]); - src21 = conv_f16_to_double((ULong)srcB_word[j][1]); + if (( inst == XVF16GER2 ) || ( inst == XVF16GER2PP ) + || ( inst == XVF16GER2PN ) || ( inst == XVF16GER2NP ) + || ( inst == XVF16GER2NN )) { + src11 = conv_f16_to_double((ULong)srcA_word[i][1]); + src21 = conv_f16_to_double((ULong)srcB_word[j][1]); + } else { + /* Input is in bfloat format, result is stored in the + "traditional" 64-bit float format. */ + src11 = (double)conv_bf16_to_float((ULong)srcA_word[i][1]); + src21 = (double)conv_bf16_to_float((ULong)srcB_word[j][1]); + } } - prod = src10 * src20; msum = prod + src11 * src21; @@ -2027,26 +2196,26 @@ void vsx_matrix_16bit_float_ger_dirty_helper( VexGuestPPC64State* gst, /* Note, we do not track the exception handling bits ox, ux, xx, si, mz, vxsnan and vximz in the FPSCR. 
*/ - if ( inst == XVF16GER2 ) + if (( inst == XVF16GER2 ) || ( inst == XVBF16GER2 ) ) result[j] = reinterpret_float_as_int( conv_double_to_float(msum) ); - else if ( inst == XVF16GER2PP ) + else if (( inst == XVF16GER2PP ) || (inst == XVBF16GER2PP )) result[j] = reinterpret_float_as_int( conv_double_to_float(msum) + acc_word[j] ); - else if ( inst == XVF16GER2PN ) + else if (( inst == XVF16GER2PN ) || ( inst == XVBF16GER2PN )) result[j] = reinterpret_float_as_int( conv_double_to_float(msum) + negate_float( acc_word[j] ) ); - else if ( inst == XVF16GER2NP ) + else if (( inst == XVF16GER2NP ) || ( inst == XVBF16GER2NP )) result[j] = reinterpret_float_as_int( conv_double_to_float( negate_double( msum ) ) + acc_word[j] ); - else if ( inst == XVF16GER2NN ) + else if (( inst == XVF16GER2NN ) || ( inst == XVBF16GER2NN )) result[j] = reinterpret_float_as_int( conv_double_to_float( negate_double( msum ) ) + negate_float( acc_word[j] ) ); diff --git a/VEX/priv/guest_ppc_toIR.c b/VEX/priv/guest_ppc_toIR.c index 354be6b53d..20553a5394 100644 --- a/VEX/priv/guest_ppc_toIR.c +++ b/VEX/priv/guest_ppc_toIR.c @@ -5688,6 +5688,57 @@ static IRExpr * convert_from_national ( const VexAbiInfo* vbi, IRExpr *src ) { return mkexpr( result ); } +static IRExpr * vector_convert_floattobf16 ( const VexAbiInfo* vbi, + IRExpr *src ) { + /* The function takes a 128-bit value containing four 32-bit floats and + returns a 128-bit value containing four 16-bit bfloats in the lower + halfwords. 
*/ + + IRTemp resultHi = newTemp( Ity_I64); + IRTemp resultLo = newTemp( Ity_I64); + + assign( resultHi, + mkIRExprCCall( Ity_I64, 0 /*regparms*/, + "vector_convert_floattobf16_helper", + fnptr_to_fnentry( vbi, + &convert_from_floattobf16_helper ), + mkIRExprVec_1( unop( Iop_V128HIto64, src ) ) ) ); + + assign( resultLo, + mkIRExprCCall( Ity_I64, 0 /*regparms*/, + "vector_convert_floattobf16_helper", + fnptr_to_fnentry( vbi, + &convert_from_floattobf16_helper ), + mkIRExprVec_1( unop( Iop_V128to64, src ) ) ) ); + + return binop( Iop_64HLtoV128, mkexpr( resultHi ), mkexpr( resultLo ) ); +} + +static IRExpr * vector_convert_bf16tofloat ( const VexAbiInfo* vbi, + IRExpr *src ) { + /* The function takes a 128-bit value containing four 16-bit bfloats in + the lower halfwords and returns a 128-bit value containing four + 32-bit floats. */ + IRTemp resultHi = newTemp( Ity_I64); + IRTemp resultLo = newTemp( Ity_I64); + + assign( resultHi, + mkIRExprCCall( Ity_I64, 0 /*regparms*/, + "vector_convert_bf16tofloat_helper", + fnptr_to_fnentry( vbi, + &convert_from_bf16tofloat_helper ), + mkIRExprVec_1( unop( Iop_V128HIto64, src ) ) ) ); + + assign( resultLo, + mkIRExprCCall( Ity_I64, 0 /*regparms*/, + "vector_convert_bf16tofloat_helper", + fnptr_to_fnentry( vbi, + &convert_from_bf16tofloat_helper ), + mkIRExprVec_1( unop( Iop_V128to64, src ) ) ) ); + + return binop( Iop_64HLtoV128, mkexpr( resultHi ), mkexpr( resultLo ) ); +} + static IRExpr * popcnt64 ( const VexAbiInfo* vbi, IRExpr *src ){ /* The function takes a 64-bit source and counts the number of bits in the @@ -5936,6 +5987,7 @@ static void vsx_matrix_ger ( const VexAbiInfo* vbi, case XVI16GER2: case XVI16GER2S: case XVF16GER2: + case XVBF16GER2: case XVF32GER: AT_fx = Ifx_Write; break; @@ -5943,6 +5995,10 @@ static void vsx_matrix_ger ( const VexAbiInfo* vbi, case XVI8GER4PP: case XVI16GER2PP: case XVI16GER2SPP: + case XVBF16GER2PP: + case XVBF16GER2PN: + case XVBF16GER2NP: + case XVBF16GER2NN: case XVF16GER2PP: case 
XVF16GER2PN: case XVF16GER2NP: @@ -23899,6 +23955,24 @@ dis_vxs_misc( UInt prefix, UInt theInstr, const VexAbiInfo* vbi, UInt opc2, mkexpr( sub_element1 ), mkexpr( sub_element0 ) ) ) ); + } else if ((inst_select == 16) && !prefix) { + IRTemp result = newTemp(Ity_V128); + UChar xT_addr = ifieldRegXT ( theInstr ); + UChar xB_addr = ifieldRegXB ( theInstr ); + /* Convert 16-bit bfloat to 32-bit float, not a prefix inst */ + DIP("xvcvbf16sp v%u,v%u\n", xT_addr, xB_addr); + assign( result, vector_convert_bf16tofloat( vbi, mkexpr( vB ) ) ); + putVSReg( XT, mkexpr( result) ); + + } else if ((inst_select == 17) && !prefix) { + IRTemp result = newTemp(Ity_V128); + UChar xT_addr = ifieldRegXT ( theInstr ); + UChar xB_addr = ifieldRegXB ( theInstr ); + /* Convert 32-bit float to 16-bit bfloat, not a prefix inst */ + DIP("xvcvspbf16 v%u,v%u\n", xT_addr, xB_addr); + assign( result, vector_convert_floattobf16( vbi, mkexpr( vB ) ) ); + putVSReg( XT, mkexpr( result) ); + } else if (inst_select == 23) { DIP("xxbrd v%u, v%u\n", (UInt)XT, (UInt)XB); @@ -34956,6 +35030,41 @@ static Bool dis_vsx_accumulator_prefix ( UInt prefix, UInt theInstr, getVSReg( rB_addr ), AT, ( ( inst_prefix << 8 ) | XO ) ); break; + case XVBF16GER2: + DIP("xvbf16ger2 %u,r%u, r%u\n", AT, rA_addr, rB_addr); + vsx_matrix_ger( vbi, MATRIX_16BIT_FLOAT_GER, + getVSReg( rA_addr ), + getVSReg( rB_addr ), AT, + ( ( inst_prefix << 8 ) | XO ) ); + break; + case XVBF16GER2PP: + DIP("xvbf16ger2pp %u,r%u, r%u\n", AT, rA_addr, rB_addr); + vsx_matrix_ger( vbi, MATRIX_16BIT_FLOAT_GER, + getVSReg( rA_addr ), + getVSReg( rB_addr ), AT, + ( ( inst_prefix << 8 ) | XO ) ); + break; + case XVBF16GER2PN: + DIP("xvbf16ger2pn %u,r%u, r%u\n", AT, rA_addr, rB_addr); + vsx_matrix_ger( vbi, MATRIX_16BIT_FLOAT_GER, + getVSReg( rA_addr ), + getVSReg( rB_addr ), AT, + ( ( inst_prefix << 8 ) | XO ) ); + break; + case XVBF16GER2NP: + DIP("xvbf16ger2np %u,r%u, r%u\n", AT, rA_addr, rB_addr); + vsx_matrix_ger( vbi, MATRIX_16BIT_FLOAT_GER, + 
getVSReg( rA_addr ), + getVSReg( rB_addr ), AT, + ( ( inst_prefix << 8 ) | XO ) ); + break; + case XVBF16GER2NN: + DIP("xvbf16ger2nn %u,r%u, r%u\n", AT, rA_addr, rB_addr); + vsx_matrix_ger( vbi, MATRIX_16BIT_FLOAT_GER, + getVSReg( rA_addr ), + getVSReg( rB_addr ), AT, + ( ( inst_prefix << 8 ) | XO ) ); + break; case XVF32GER: DIP("xvf32ger %u,r%u, r%u\n", AT, rA_addr, rB_addr); vsx_matrix_ger( vbi, MATRIX_32BIT_FLOAT_GER, @@ -35106,6 +35215,61 @@ static Bool dis_vsx_accumulator_prefix ( UInt prefix, UInt theInstr, AT, ( (MASKS << 9 ) | ( inst_prefix << 8 ) | XO ) ); break; + case XVBF16GER2: + PMSK = IFIELD( prefix, 14, 2); + XMSK = IFIELD( prefix, 4, 4); + YMSK = IFIELD( prefix, 0, 4); + DIP("pmxvbf16ger2 %u,r%u, r%u\n", AT, rA_addr, rB_addr); + vsx_matrix_ger( vbi, MATRIX_16BIT_FLOAT_GER, + getVSReg( rA_addr ), + getVSReg( rB_addr ), + AT, ( (MASKS << 9 ) + | ( inst_prefix << 8 ) | XO ) ); + break; + case XVBF16GER2PP: + PMSK = IFIELD( prefix, 14, 2); + XMSK = IFIELD( prefix, 4, 4); + YMSK = IFIELD( prefix, 0, 4); + DIP("pmxvbf16ger2pp %u,r%u, r%u\n", AT, rA_addr, rB_addr); + vsx_matrix_ger( vbi, MATRIX_16BIT_FLOAT_GER, + getVSReg( rA_addr ), + getVSReg( rB_addr ), + AT, ( (MASKS << 9 ) + | ( inst_prefix << 8 ) | XO ) ); + break; + case XVBF16GER2PN: + PMSK = IFIELD( prefix, 14, 2); + XMSK = IFIELD( prefix, 4, 4); + YMSK = IFIELD( prefix, 0, 4); + DIP("pmxvbf16ger2pn %u,r%u, r%u\n", AT, rA_addr, rB_addr); + vsx_matrix_ger( vbi, MATRIX_16BIT_FLOAT_GER, + getVSReg( rA_addr ), + getVSReg( rB_addr ), + AT, ( (MASKS << 9 ) + | ( inst_prefix << 8 ) | XO ) ); + break; + case XVBF16GER2NP: + PMSK = IFIELD( prefix, 14, 2); + XMSK = IFIELD( prefix, 4, 4); + YMSK = IFIELD( prefix, 0, 4); + DIP("pmxvbf16ger2np %u,r%u, r%u\n", AT, rA_addr, rB_addr); + vsx_matrix_ger( vbi, MATRIX_16BIT_FLOAT_GER, + getVSReg( rA_addr ), + getVSReg( rB_addr ), + AT, ( (MASKS << 9 ) + | ( inst_prefix << 8 ) | XO ) ); + break; + case XVBF16GER2NN: + PMSK = IFIELD( prefix, 14, 2); + XMSK = IFIELD( 
prefix, 4, 4); + YMSK = IFIELD( prefix, 0, 4); + DIP("pmxvbf16ger2nn %u,r%u, r%u\n", AT, rA_addr, rB_addr); + vsx_matrix_ger( vbi, MATRIX_16BIT_FLOAT_GER, + getVSReg( rA_addr ), + getVSReg( rB_addr ), + AT, ( (MASKS << 9 ) + | ( inst_prefix << 8 ) | XO ) ); + break; case XVF16GER2: PMSK = IFIELD( prefix, 14, 2); XMSK = IFIELD( prefix, 4, 4); @@ -36181,6 +36345,11 @@ DisResult disInstr_PPC_WRK ( (opc2 == XVI4GER8PP) || // xvi4ger8pp (opc2 == XVI8GER4) || // xvi8ger4 (opc2 == XVI8GER4PP) || // xvi8ger4pp + (opc2 == XVBF16GER2) || // xvbf16ger2 + (opc2 == XVBF16GER2PP) || // xvbf16ger2pp + (opc2 == XVBF16GER2PN) || // xvbf16ger2pn + (opc2 == XVBF16GER2NP) || // xvbf16ger2np + (opc2 == XVBF16GER2NN) || // xvbf16ger2nn (opc2 == XVF16GER2) || // xvf16ger2 (opc2 == XVF16GER2PP) || // xvf16ger2pp (opc2 == XVF16GER2PN) || // xvf16ger2pn