From: Richard Levitte Date: Wed, 3 Dec 2025 20:21:37 +0000 (+0100) Subject: OSSL_FN: Refactor OSSL_FN_add() and OSSL_FN_sub() for truncation X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=refs%2Fheads%2Ffeature%2Fossl_fn;p=thirdparty%2Fopenssl.git OSSL_FN: Refactor OSSL_FN_add() and OSSL_FN_sub() for truncation OSSL_FN_mul() set a path that wasn't considered for OSSL_FN_add() and OSSL_FN_sub(); a truncated result if the result OSSL_FN isn't large enough to contain the full result. This is done to keep the OSSL_FN API consistent, with a (tentative) bonus, that the function calls become more constant time accross repeated calls with the same size for operands and result. Reviewed-by: Matt Caswell Reviewed-by: Tomas Mraz (Merged from https://github.com/openssl/openssl/pull/29309) --- diff --git a/crypto/fn/fn_addsub.c b/crypto/fn/fn_addsub.c index 6c81c3c7330..53b05308193 100644 --- a/crypto/fn/fn_addsub.c +++ b/crypto/fn/fn_addsub.c @@ -33,33 +33,53 @@ int OSSL_FN_add(OSSL_FN *r, const OSSL_FN *a, const OSSL_FN *b) size_t max = a->dsize; size_t min = b->dsize; size_t rs = r->dsize; + const OSSL_FN_ULONG *ap = a->d; + const OSSL_FN_ULONG *bp = b->d; + OSSL_FN_ULONG *rp = r->d; /* - * 'r' must be able to contain the result. This is on the caller. + * Three stages, after which there's a possible return. + * Each stage is limited by the number of remaining limbs + * in |r|. + * + * For each stage, |stage_limbs| is used to hold the number + * of limbs being treated in that stage, and |borrow| is + * used to transport the borrow from one stage to the other. */ - if (!ossl_assert(max <= rs)) { - ERR_raise(ERR_LIB_OSSL_FN, OSSL_FN_R_RESULT_ARG_TOO_SMALL); - return 0; - } + size_t stage_limbs; + OSSL_FN_ULONG carry; - const OSSL_FN_ULONG *ap = a->d; - const OSSL_FN_ULONG *bp = b->d; + /* Stage 1 */ - OSSL_FN_ULONG *rp = r->d; - OSSL_FN_ULONG carry = bn_add_words(rp, ap, bp, (int)min); + stage_limbs = (min > rs) ? rs : min; + carry = bn_add_words(rp, ap, bp, (int)stage_limbs); + if (stage_limbs == rs) + return 1; + /* At this point, we know that |min| limbs have been used so far */ rp += min; ap += min; - for (size_t dif = max - min; dif > 0; dif--, ap++, rp++) { + /* Stage 2 */ + + stage_limbs = ((max > rs) ? rs : max) - min; + + for (size_t dif = stage_limbs; dif > 0; dif--, ap++, rp++) { OSSL_FN_ULONG t1 = *ap; OSSL_FN_ULONG t2 = (t1 + carry) & OSSL_FN_MASK; *rp = t2; carry &= (t2 == 0); } + if (stage_limbs == rs) + return 1; - for (size_t dif = r->dsize - max; dif > 0; dif--, rp++) { + /* Stage 3 */ + + /* We know that |max| limbs have been used */ + stage_limbs = rs - max; + + for (size_t dif = stage_limbs; dif > 0; dif--, rp++) { OSSL_FN_ULONG t1 = 0; OSSL_FN_ULONG t2 = (t1 + carry) & OSSL_FN_MASK; @@ -76,45 +96,61 @@ int OSSL_FN_sub(OSSL_FN *r, const OSSL_FN *a, const OSSL_FN *b) size_t max = (a->dsize >= b->dsize) ? a->dsize : b->dsize; size_t min = (a->dsize <= b->dsize) ? a->dsize : b->dsize; size_t rs = r->dsize; - - /* - * 'r' must be able to contain the result. This is on the caller. - */ - if (!ossl_assert(max <= rs)) { - ERR_raise(ERR_LIB_OSSL_FN, OSSL_FN_R_RESULT_ARG_TOO_SMALL); - return 0; - } - const OSSL_FN_ULONG *ap = a->d; const OSSL_FN_ULONG *bp = b->d; OSSL_FN_ULONG *rp = r->d; - OSSL_FN_ULONG borrow = bn_sub_words(rp, ap, bp, (int)min); /* - * TODO(FIXNUM): everything following isn't strictly constant-time, - * and could use improvement in that regard. + * Three stages, after which there's a possible return. + * Each stage is limited by the number of remaining limbs + * in |r|. + * + * For each stage, |stage_limbs| is used to hold the number + * of limbs being treated in that stage, and |borrow| is + * used to transport the borrow from one stage to the other. */ + size_t stage_limbs; + OSSL_FN_ULONG borrow; + + /* Stage 1 */ + + stage_limbs = (min > rs) ? rs : min; + borrow = bn_sub_words(rp, ap, bp, (int)stage_limbs); + if (stage_limbs == rs) + return 1; + /* At this point, we know that |min| limbs have been used so far */ ap += min; bp += min; rp += min; + /* Stage 2 */ + const OSSL_FN_ULONG *maxp = (a->dsize >= b->dsize) ? ap : bp; + const OSSL_FN_ULONG s2_mask1 = (a->dsize >= b->dsize) ? OSSL_FN_MASK : 0; + const OSSL_FN_ULONG s2_mask2 = ~s2_mask1; - /* "sign" borrow, depending on if maxp == ap or maxp == bp */ - borrow *= (OSSL_FN_ULONG)((a->dsize >= b->dsize) ? 1 : -1); + stage_limbs = ((max > rs) ? rs : max) - min; /* calculate the result of borrowing from more significant limbs */ - for (size_t dif = max - min; dif > 0; dif--, maxp++, rp++) { - OSSL_FN_ULONG t1 = *maxp; - OSSL_FN_ULONG t2 = (t1 - borrow) & OSSL_FN_MASK; + for (size_t dif = stage_limbs; dif > 0; dif--, maxp++, rp++) { + OSSL_FN_ULONG t1 = (*maxp & s2_mask1); + OSSL_FN_ULONG t2 = (*maxp & s2_mask2); + OSSL_FN_ULONG t3 = (t1 - t2 - borrow) & OSSL_FN_MASK; - *rp = t2; - borrow &= (t1 == 0); + *rp = t3; + borrow &= (t1 <= t2); } + if (stage_limbs == rs) + return 1; + + /* Stage 3 */ + + /* We know that |max| limbs have been used */ + stage_limbs = rs - max; /* Finally, fill in the rest of the result array by borrowing from zeros */ - for (size_t dif = rs - max; dif > 0; dif--, rp++) { + for (size_t dif = stage_limbs; dif > 0; dif--, rp++) { OSSL_FN_ULONG t1 = 0; OSSL_FN_ULONG t2 = (t1 - borrow) & OSSL_FN_MASK; diff --git a/test/fn_api_test.c b/test/fn_api_test.c index 880d0b275c5..ad0dc2a0283 100644 --- a/test/fn_api_test.c +++ b/test/fn_api_test.c @@ -41,7 +41,8 @@ static int pollute(OSSL_FN *f, size_t start, size_t end) return 1; } -static int check_zero(const OSSL_FN *f, size_t start, size_t end) +static int check_limbs_value(const OSSL_FN *f, size_t start, size_t end, + OSSL_FN_ULONG value) { const OSSL_FN_ULONG *u = ossl_fn_get_words(f); size_t l = ossl_fn_get_dsize(f); @@ -52,8 +53,10 @@ static int check_zero(const OSSL_FN *f, size_t start, size_t end) start = end; for (size_t i = start; i < end; i++) - if (u[i] != 0) + if (!TEST_size_t_eq(u[i], value)) { + TEST_note("start = %zu, end = %zu, i = %zu\n", start, end, i); return 0; + } return 1; } @@ -95,6 +98,11 @@ struct test_case_st { /* Number of limbs to compare the result's OSSL_FN_ULONG array against ex */ size_t check_size; + + /* When the result is larger than check_size, the expected extended value */ + OSSL_FN_ULONG extended_limb_value; +#define EXTENDED_LIMB_ZERO ((OSSL_FN_ULONG)0) +#define EXTENDED_LIMB_MINUS_ONE ((OSSL_FN_ULONG)-1) }; static const OSSL_FN_ULONG ex_add_num0_num0[] = { @@ -131,18 +139,64 @@ static const OSSL_FN_ULONG ex_add_num3_num3[] = { OSSL_FN_ULONG_C(0x1), }; -#define ADD_CASE(i, op1, op2, ex) \ - { \ - /* op1 */ op1, \ - /* op1_size */ LIMBSOF(op1), \ - /* op2 */ op2, \ - /* op2_size */ LIMBSOF(op2), \ - /* ex */ ex, \ - /* ex_size */ LIMBSOF(ex), \ - /* op1_live_size */ LIMBSOF(op1), \ - /* op2_live_size */ LIMBSOF(op2), \ - /* res_live_size */ LIMBSOF(ex), \ - /* check_size */ LIMBSOF(ex), \ +static int test_add_common(struct test_case_st test_case) +{ + const OSSL_FN_ULONG *n1 = test_case.op1; + size_t n1_limbs = test_case.op1_size; + const OSSL_FN_ULONG *n2 = test_case.op2; + size_t n2_limbs = test_case.op2_size; + const OSSL_FN_ULONG *ex = test_case.ex; + size_t n1_new_limbs = test_case.op1_live_size; + size_t n2_new_limbs = test_case.op2_live_size; + size_t res_limbs = test_case.res_live_size; + size_t check_limbs = test_case.check_size; + OSSL_FN_ULONG extended_value = test_case.extended_limb_value; + int ret = 1; + OSSL_FN *fn1 = NULL, *fn2 = NULL, *res = NULL; + const OSSL_FN_ULONG *u = NULL; + + /* To test that OSSL_FN_add() does a complete job, 'res' is pre-polluted */ + + if (!TEST_ptr(fn1 = OSSL_FN_new_limbs(n1_new_limbs)) + || !TEST_ptr(fn2 = OSSL_FN_new_limbs(n2_new_limbs)) + || !TEST_true(ossl_fn_set_words(fn1, n1, n1_limbs)) + || !TEST_true(ossl_fn_set_words(fn2, n2, n2_limbs)) + || !TEST_ptr(res = OSSL_FN_new_limbs(res_limbs)) + || !TEST_true(pollute(res, 0, res_limbs))) { + res = 0; + /* There's no way to continue tests in this case */ + goto end; + } + + if (!TEST_true(OSSL_FN_add(res, fn1, fn2)) + || !TEST_ptr(u = ossl_fn_get_words(res)) + || !TEST_mem_eq(u, check_limbs * OSSL_FN_BYTES, + ex, check_limbs * OSSL_FN_BYTES) + || !TEST_true(check_limbs_value(res, check_limbs, res_limbs, + extended_value))) + ret = 0; + +end: + OSSL_FN_free(fn1); + OSSL_FN_free(fn2); + OSSL_FN_free(res); + + return ret; +} + +#define ADD_CASE(i, op1, op2, ex) \ + { \ + /* op1 */ op1, \ + /* op1_size */ LIMBSOF(op1), \ + /* op2 */ op2, \ + /* op2_size */ LIMBSOF(op2), \ + /* ex */ ex, \ + /* ex_size */ LIMBSOF(ex), \ + /* op1_live_size */ LIMBSOF(op1) + 1, \ + /* op2_live_size */ LIMBSOF(op2) + 2, \ + /* res_live_size */ LIMBSOF(ex) + 3, \ + /* check_size */ LIMBSOF(ex), \ + /* extended_limb_value */ EXTENDED_LIMB_ZERO, \ } static struct test_case_st test_add_cases[] = { @@ -166,35 +220,46 @@ static struct test_case_st test_add_cases[] = { static int test_add(int i) { - const OSSL_FN_ULONG *n1 = test_add_cases[i].op1; - size_t n1_limbs = test_add_cases[i].op1_size; - const OSSL_FN_ULONG *n2 = test_add_cases[i].op2; - size_t n2_limbs = test_add_cases[i].op2_size; - const OSSL_FN_ULONG *ex = test_add_cases[i].ex; - size_t ex_limbs = test_add_cases[i].ex_size; - size_t check_limbs = test_add_cases[i].check_size; - int ret = 1; - OSSL_FN *fn1 = NULL, *fn2 = NULL, *res = NULL; - const OSSL_FN_ULONG *u = NULL; - - /* To test that OSSL_FN_add() does a complete job, 'res' is pre-polluted */ + return test_add_common(test_add_cases[i]); +} - if (!TEST_ptr(fn1 = OSSL_FN_new_limbs(n1_limbs)) - || !TEST_ptr(fn2 = OSSL_FN_new_limbs(n2_limbs)) - || !TEST_ptr(res = OSSL_FN_new_limbs(ex_limbs)) - || !TEST_true(pollute(res, 0, ex_limbs)) - || !TEST_true(ossl_fn_set_words(fn1, n1, n1_limbs)) - || !TEST_true(ossl_fn_set_words(fn2, n2, n2_limbs)) - || !TEST_true(OSSL_FN_add(res, fn1, fn2)) - || !TEST_ptr(u = ossl_fn_get_words(res)) - || !TEST_mem_eq(u, ossl_fn_get_dsize(res) * OSSL_FN_BYTES, - ex, check_limbs * OSSL_FN_BYTES)) - ret = 0; - OSSL_FN_free(fn1); - OSSL_FN_free(fn2); - OSSL_FN_free(res); +#define ADD_TRUNCATED_CASE(i, op1, op2, ex) \ + { \ + /* op1 */ op1, \ + /* op1_size */ LIMBSOF(op1), \ + /* op2 */ op2, \ + /* op2_size */ LIMBSOF(op2), \ + /* ex */ ex, \ + /* ex_size */ LIMBSOF(ex), \ + /* op1_live_size */ LIMBSOF(op1) + 1, \ + /* op2_live_size */ LIMBSOF(op2) + 2, \ + /* res_live_size */ LIMBSOF(ex) - 1, \ + /* check_size */ LIMBSOF(ex) - 1, \ + /* extended_limb_value */ EXTENDED_LIMB_ZERO, \ + } - return ret; +static struct test_case_st test_add_truncated_cases[] = { + ADD_TRUNCATED_CASE(1, num0, num0, ex_add_num0_num0), + ADD_TRUNCATED_CASE(2, num0, num1, ex_add_num0_num1), + ADD_TRUNCATED_CASE(3, num0, num2, ex_add_num0_num2), + ADD_TRUNCATED_CASE(4, num0, num3, ex_add_num0_num3), + ADD_TRUNCATED_CASE(5, num1, num0, ex_add_num0_num1), /* Commutativity check */ + ADD_TRUNCATED_CASE(6, num1, num1, ex_add_num1_num1), + ADD_TRUNCATED_CASE(7, num1, num2, ex_add_num1_num2), + ADD_TRUNCATED_CASE(8, num1, num3, ex_add_num1_num3), + ADD_TRUNCATED_CASE(9, num2, num0, ex_add_num0_num2), /* Commutativity check */ + ADD_TRUNCATED_CASE(10, num2, num1, ex_add_num1_num2), /* Commutativity check */ + ADD_TRUNCATED_CASE(11, num2, num2, ex_add_num2_num2), + ADD_TRUNCATED_CASE(12, num2, num3, ex_add_num2_num3), + ADD_TRUNCATED_CASE(13, num3, num0, ex_add_num0_num3), /* Commutativity check */ + ADD_TRUNCATED_CASE(14, num3, num1, ex_add_num1_num3), /* Commutativity check */ + ADD_TRUNCATED_CASE(15, num3, num2, ex_add_num2_num3), /* Commutativity check */ + ADD_TRUNCATED_CASE(16, num3, num3, ex_add_num3_num3), +}; + +static int test_add_truncated(int i) +{ + return test_add_common(test_add_truncated_cases[i]); } static const OSSL_FN_ULONG ex_sub_num0_num0[] = { @@ -252,65 +317,44 @@ static const OSSL_FN_ULONG ex_sub_num3_num3[] = { OSSL_FN_ULONG64_C(0x00000000, 0x00000000), }; -#define SUB_CASE(i, op1, op2, ex) \ - { \ - /* op1, with size */ op1, \ - LIMBSOF(op1), \ - /* op2, with size */ op2, \ - LIMBSOF(op2), \ - /* ex, with size */ ex, \ - LIMBSOF(ex), \ - /* op1_live_size */ LIMBSOF(op1), \ - /* op2_live_size */ LIMBSOF(op2), \ - /* res_live_size */ LIMBSOF(ex), \ - /* check_size */ LIMBSOF(ex), \ - } - -static struct test_case_st test_sub_cases[] = { - SUB_CASE(1, num0, num0, ex_sub_num0_num0), - SUB_CASE(2, num0, num1, ex_sub_num0_num1), - SUB_CASE(3, num0, num2, ex_sub_num0_num2), - SUB_CASE(4, num0, num3, ex_sub_num0_num3), - SUB_CASE(5, num1, num0, ex_sub_num1_num0), - SUB_CASE(6, num1, num1, ex_sub_num1_num1), - SUB_CASE(7, num1, num2, ex_sub_num1_num2), - SUB_CASE(8, num1, num3, ex_sub_num1_num3), - SUB_CASE(9, num2, num0, ex_sub_num2_num0), - SUB_CASE(10, num2, num1, ex_sub_num2_num1), - SUB_CASE(11, num2, num2, ex_sub_num2_num2), - SUB_CASE(12, num2, num3, ex_sub_num2_num3), - SUB_CASE(13, num3, num0, ex_sub_num3_num0), - SUB_CASE(14, num3, num1, ex_sub_num3_num1), - SUB_CASE(15, num3, num2, ex_sub_num3_num2), - SUB_CASE(16, num3, num3, ex_sub_num3_num3), -}; - -static int test_sub(int i) +static int test_sub_common(struct test_case_st test_case) { - const OSSL_FN_ULONG *n1 = test_sub_cases[i].op1; - size_t n1_limbs = test_sub_cases[i].op1_size; - const OSSL_FN_ULONG *n2 = test_sub_cases[i].op2; - size_t n2_limbs = test_sub_cases[i].op2_size; - const OSSL_FN_ULONG *ex = test_sub_cases[i].ex; - size_t ex_limbs = test_sub_cases[i].ex_size; - size_t check_limbs = test_sub_cases[i].check_size; + const OSSL_FN_ULONG *n1 = test_case.op1; + size_t n1_limbs = test_case.op1_size; + const OSSL_FN_ULONG *n2 = test_case.op2; + size_t n2_limbs = test_case.op2_size; + const OSSL_FN_ULONG *ex = test_case.ex; + size_t n1_new_limbs = test_case.op1_live_size; + size_t n2_new_limbs = test_case.op2_live_size; + size_t res_limbs = test_case.res_live_size; + size_t check_limbs = test_case.check_size; + OSSL_FN_ULONG extended_value = test_case.extended_limb_value; int ret = 1; OSSL_FN *fn1 = NULL, *fn2 = NULL, *res = NULL; const OSSL_FN_ULONG *u = NULL; /* To test that OSSL_FN_sub() does a complete job, 'res' is pre-polluted */ - if (!TEST_ptr(fn1 = OSSL_FN_new_limbs(n1_limbs)) - || !TEST_ptr(fn2 = OSSL_FN_new_limbs(n2_limbs)) - || !TEST_ptr(res = OSSL_FN_new_limbs(ex_limbs)) - || !TEST_true(pollute(res, 0, ex_limbs)) + if (!TEST_ptr(fn1 = OSSL_FN_new_limbs(n1_new_limbs)) + || !TEST_ptr(fn2 = OSSL_FN_new_limbs(n2_new_limbs)) || !TEST_true(ossl_fn_set_words(fn1, n1, n1_limbs)) || !TEST_true(ossl_fn_set_words(fn2, n2, n2_limbs)) - || !TEST_true(OSSL_FN_sub(res, fn1, fn2)) + || !TEST_ptr(res = OSSL_FN_new_limbs(res_limbs)) + || !TEST_true(pollute(res, 0, res_limbs))) { + res = 0; + /* There's no way to continue tests in this case */ + goto end; + } + + if (!TEST_true(OSSL_FN_sub(res, fn1, fn2)) || !TEST_ptr(u = ossl_fn_get_words(res)) - || !TEST_mem_eq(u, ossl_fn_get_dsize(res) * OSSL_FN_BYTES, - ex, check_limbs * OSSL_FN_BYTES)) + || !TEST_mem_eq(u, check_limbs * OSSL_FN_BYTES, + ex, check_limbs * OSSL_FN_BYTES) + || !TEST_true(check_limbs_value(res, check_limbs, res_limbs, + extended_value))) ret = 0; + +end: OSSL_FN_free(fn1); OSSL_FN_free(fn2); OSSL_FN_free(res); @@ -318,6 +362,84 @@ static int test_sub(int i) return ret; } +#define SUB_CASE(i, op1, op2, ex, ext) \ + { \ + /* op1 */ op1, \ + /* op1_size */ LIMBSOF(op1), \ + /* op2 */ op2, \ + /* op2_size */ LIMBSOF(op2), \ + /* ex */ ex, \ + /* ex_size */ LIMBSOF(ex), \ + /* op1_live_size */ LIMBSOF(op1) + 1, \ + /* op2_live_size */ LIMBSOF(op2) + 2, \ + /* res_live_size */ LIMBSOF(ex) + 3, \ + /* check_size */ LIMBSOF(ex), \ + /* extended_limb_value */ (ext), \ + } + +static struct test_case_st test_sub_cases[] = { + SUB_CASE(1, num0, num0, ex_sub_num0_num0, EXTENDED_LIMB_ZERO), + SUB_CASE(2, num0, num1, ex_sub_num0_num1, EXTENDED_LIMB_ZERO), + SUB_CASE(3, num0, num2, ex_sub_num0_num2, EXTENDED_LIMB_ZERO), + SUB_CASE(4, num0, num3, ex_sub_num0_num3, EXTENDED_LIMB_MINUS_ONE), + SUB_CASE(5, num1, num0, ex_sub_num1_num0, EXTENDED_LIMB_MINUS_ONE), + SUB_CASE(6, num1, num1, ex_sub_num1_num1, EXTENDED_LIMB_ZERO), + SUB_CASE(7, num1, num2, ex_sub_num1_num2, EXTENDED_LIMB_MINUS_ONE), + SUB_CASE(8, num1, num3, ex_sub_num1_num3, EXTENDED_LIMB_MINUS_ONE), + SUB_CASE(9, num2, num0, ex_sub_num2_num0, EXTENDED_LIMB_MINUS_ONE), + SUB_CASE(10, num2, num1, ex_sub_num2_num1, EXTENDED_LIMB_ZERO), + SUB_CASE(11, num2, num2, ex_sub_num2_num2, EXTENDED_LIMB_ZERO), + SUB_CASE(12, num2, num3, ex_sub_num2_num3, EXTENDED_LIMB_MINUS_ONE), + SUB_CASE(13, num3, num0, ex_sub_num3_num0, EXTENDED_LIMB_ZERO), + SUB_CASE(14, num3, num1, ex_sub_num3_num1, EXTENDED_LIMB_ZERO), + SUB_CASE(15, num3, num2, ex_sub_num3_num2, EXTENDED_LIMB_ZERO), + SUB_CASE(16, num3, num3, ex_sub_num3_num3, EXTENDED_LIMB_ZERO), +}; + +static int test_sub(int i) +{ + return test_sub_common(test_sub_cases[i]); +} + +#define SUB_TRUNCATED_CASE(i, op1, op2, ex) \ + { \ + /* op1 */ op1, \ + /* op1_size */ LIMBSOF(op1), \ + /* op2 */ op2, \ + /* op2_size */ LIMBSOF(op2), \ + /* ex */ ex, \ + /* ex_size */ LIMBSOF(ex), \ + /* op1_live_size */ LIMBSOF(op1) + 1, \ + /* op2_live_size */ LIMBSOF(op2) + 2, \ + /* res_live_size */ LIMBSOF(ex) - 1, \ + /* check_size */ LIMBSOF(ex) - 1, \ + /* extended_limb_value */ EXTENDED_LIMB_ZERO, \ + } + +static struct test_case_st test_sub_truncated_cases[] = { + SUB_TRUNCATED_CASE(1, num0, num0, ex_sub_num0_num0), + SUB_TRUNCATED_CASE(2, num0, num1, ex_sub_num0_num1), + SUB_TRUNCATED_CASE(3, num0, num2, ex_sub_num0_num2), + SUB_TRUNCATED_CASE(4, num0, num3, ex_sub_num0_num3), + SUB_TRUNCATED_CASE(5, num1, num0, ex_sub_num1_num0), + SUB_TRUNCATED_CASE(6, num1, num1, ex_sub_num1_num1), + SUB_TRUNCATED_CASE(7, num1, num2, ex_sub_num1_num2), + SUB_TRUNCATED_CASE(8, num1, num3, ex_sub_num1_num3), + SUB_TRUNCATED_CASE(9, num2, num0, ex_sub_num2_num0), + SUB_TRUNCATED_CASE(10, num2, num1, ex_sub_num2_num1), + SUB_TRUNCATED_CASE(11, num2, num2, ex_sub_num2_num2), + SUB_TRUNCATED_CASE(12, num2, num3, ex_sub_num2_num3), + SUB_TRUNCATED_CASE(13, num3, num0, ex_sub_num3_num0), + SUB_TRUNCATED_CASE(14, num3, num1, ex_sub_num3_num1), + SUB_TRUNCATED_CASE(15, num3, num2, ex_sub_num3_num2), + SUB_TRUNCATED_CASE(16, num3, num3, ex_sub_num3_num3), +}; + +static int test_sub_truncated(int i) +{ + return test_sub_common(test_sub_truncated_cases[i]); +} + /* A set of expected results, also in OSSL_FN_ULONG array form */ static const OSSL_FN_ULONG ex_mul_num0_num0[] = { OSSL_FN_ULONG64_C(0x00000000, 0x00000001), @@ -493,6 +615,7 @@ static int test_mul_common(struct test_case_st test_case) size_t n2_new_limbs = test_case.op2_live_size; size_t res_limbs = test_case.res_live_size; size_t check_limbs = test_case.check_size; + OSSL_FN_ULONG extended_value = test_case.extended_limb_value; OSSL_FN *fn1 = NULL, *fn2 = NULL, *res = NULL; const OSSL_FN_ULONG *u = NULL; @@ -515,7 +638,8 @@ static int test_mul_common(struct test_case_st test_case) || !TEST_ptr(u = ossl_fn_get_words(res)) || !TEST_mem_eq(u, check_limbs * OSSL_FN_BYTES, ex, check_limbs * OSSL_FN_BYTES) - || !TEST_true(check_zero(res, check_limbs, res_limbs))) + || !TEST_true(check_limbs_value(res, check_limbs, res_limbs, + extended_value))) ret = 0; end: @@ -540,6 +664,7 @@ end: /* op2_live_size */ LIMBSOF(op2) + 2, \ /* res_live_size */ LIMBSOF(op1) + LIMBSOF(op2) + ((i - 1) % 4), \ /* check_size */ LIMBSOF(ex), \ + /* extended_limb_value */ EXTENDED_LIMB_ZERO, \ } static struct test_case_st test_mul_cases[] = { @@ -572,18 +697,19 @@ static int test_mul(int i) } /* i should be set to match the iteration number that's displayed when testing */ -#define MUL_TRUNCATED_CASE(i, op1, op2, ex) \ - { \ - /* op1 */ op1, \ - /* op1_size */ LIMBSOF(op1), \ - /* op2 */ op2, \ - /* op2_size */ LIMBSOF(op2), \ - /* ex */ ex, \ - /* ex_size */ LIMBSOF(ex), \ - /* op1_live_size */ LIMBSOF(op1) + 1, \ - /* op2_live_size */ LIMBSOF(op2) + 2, \ - /* res_live_size */ LIMBSOF(ex) / 2, \ - /* check_size */ LIMBSOF(ex) / 2, \ +#define MUL_TRUNCATED_CASE(i, op1, op2, ex) \ + { \ + /* op1 */ op1, \ + /* op1_size */ LIMBSOF(op1), \ + /* op2 */ op2, \ + /* op2_size */ LIMBSOF(op2), \ + /* ex */ ex, \ + /* ex_size */ LIMBSOF(ex), \ + /* op1_live_size */ LIMBSOF(op1) + 1, \ + /* op2_live_size */ LIMBSOF(op2) + 2, \ + /* res_live_size */ LIMBSOF(ex) / 2, \ + /* check_size */ LIMBSOF(ex) / 2, \ + /* extended_limb_value */ EXTENDED_LIMB_ZERO, \ } /* A special case, where the truncation is set to the size of ex minus 64 bits */ #define MUL_TRUNCATED_SPECIAL_CASE1(i, op1, op2, ex) \ @@ -598,6 +724,7 @@ static int test_mul(int i) /* op2_live_size */ LIMBSOF(op2) + 2, \ /* res_live_size */ LIMBSOF(ex) - 8 / OSSL_FN_BYTES, \ /* check_size */ LIMBSOF(ex) - 8 / OSSL_FN_BYTES, \ + /* extended_limb_value */ EXTENDED_LIMB_ZERO, \ } static struct test_case_st test_mul_truncate_cases[] = { @@ -632,7 +759,9 @@ static int test_mul_truncated(int i) int setup_tests(void) { ADD_ALL_TESTS(test_add, 16); + ADD_ALL_TESTS(test_add_truncated, 16); ADD_ALL_TESTS(test_sub, 16); + ADD_ALL_TESTS(test_sub_truncated, 16); ADD_ALL_TESTS(test_mul_feature_r_is_operand, 4); ADD_ALL_TESTS(test_mul, OSSL_NELEM(test_mul_cases)); ADD_ALL_TESTS(test_mul_truncated, OSSL_NELEM(test_mul_truncate_cases));