]> git.ipfire.org Git - thirdparty/openssl.git/commitdiff
OSSL_FN: Refactor OSSL_FN_add() and OSSL_FN_sub() for truncation feature/ossl_fn
authorRichard Levitte <levitte@openssl.org>
Wed, 3 Dec 2025 20:21:37 +0000 (21:21 +0100)
committerRichard Levitte <levitte@openssl.org>
Wed, 17 Dec 2025 12:16:00 +0000 (13:16 +0100)
OSSL_FN_mul() set a path that wasn't considered for OSSL_FN_add() and
OSSL_FN_sub(); a truncated result if the result OSSL_FN isn't large
enough to contain the full result.

This is done to keep the OSSL_FN API consistent, with a (tentative)
bonus, that the function calls become more constant time across
repeated calls with the same size for operands and result.

Reviewed-by: Matt Caswell <matt@openssl.org>
Reviewed-by: Tomas Mraz <tomas@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/29309)

crypto/fn/fn_addsub.c
test/fn_api_test.c

index 6c81c3c733075bc717a2f96c1d99c8595b885619..53b05308193625d5ce74fe4c2e16a535364981b3 100644 (file)
@@ -33,33 +33,53 @@ int OSSL_FN_add(OSSL_FN *r, const OSSL_FN *a, const OSSL_FN *b)
     size_t max = a->dsize;
     size_t min = b->dsize;
     size_t rs = r->dsize;
+    const OSSL_FN_ULONG *ap = a->d;
+    const OSSL_FN_ULONG *bp = b->d;
+    OSSL_FN_ULONG *rp = r->d;
 
     /*
-     * 'r' must be able to contain the result.  This is on the caller.
+     * Three stages, after which there's a possible return.
+     * Each stage is limited by the number of remaining limbs
+     * in |r|.
+     *
+     * For each stage, |stage_limbs| is used to hold the number
+     * of limbs being treated in that stage, and |carry| is
+     * used to transport the carry from one stage to the other.
      */
-    if (!ossl_assert(max <= rs)) {
-        ERR_raise(ERR_LIB_OSSL_FN, OSSL_FN_R_RESULT_ARG_TOO_SMALL);
-        return 0;
-    }
+    size_t stage_limbs;
+    OSSL_FN_ULONG carry;
 
-    const OSSL_FN_ULONG *ap = a->d;
-    const OSSL_FN_ULONG *bp = b->d;
+    /* Stage 1 */
 
-    OSSL_FN_ULONG *rp = r->d;
-    OSSL_FN_ULONG carry = bn_add_words(rp, ap, bp, (int)min);
+    stage_limbs = (min > rs) ? rs : min;
+    carry = bn_add_words(rp, ap, bp, (int)stage_limbs);
+    if (stage_limbs == rs)
+        return 1;
 
+    /* At this point, we know that |min| limbs have been used so far */
     rp += min;
     ap += min;
 
-    for (size_t dif = max - min; dif > 0; dif--, ap++, rp++) {
+    /* Stage 2 */
+
+    stage_limbs = ((max > rs) ? rs : max) - min;
+
+    for (size_t dif = stage_limbs; dif > 0; dif--, ap++, rp++) {
         OSSL_FN_ULONG t1 = *ap;
         OSSL_FN_ULONG t2 = (t1 + carry) & OSSL_FN_MASK;
 
         *rp = t2;
         carry &= (t2 == 0);
     }
+    if (max >= rs)
+        return 1;
 
-    for (size_t dif = r->dsize - max; dif > 0; dif--, rp++) {
+    /* Stage 3 */
+
+    /* Stage 2 returned unless |rs| > |max|, so |rs| - |max| can't wrap */
+    stage_limbs = rs - max;
+
+    for (size_t dif = stage_limbs; dif > 0; dif--, rp++) {
         OSSL_FN_ULONG t1 = 0;
         OSSL_FN_ULONG t2 = (t1 + carry) & OSSL_FN_MASK;
 
@@ -76,45 +96,61 @@ int OSSL_FN_sub(OSSL_FN *r, const OSSL_FN *a, const OSSL_FN *b)
     size_t max = (a->dsize >= b->dsize) ? a->dsize : b->dsize;
     size_t min = (a->dsize <= b->dsize) ? a->dsize : b->dsize;
     size_t rs = r->dsize;
-
-    /*
-     * 'r' must be able to contain the result.  This is on the caller.
-     */
-    if (!ossl_assert(max <= rs)) {
-        ERR_raise(ERR_LIB_OSSL_FN, OSSL_FN_R_RESULT_ARG_TOO_SMALL);
-        return 0;
-    }
-
     const OSSL_FN_ULONG *ap = a->d;
     const OSSL_FN_ULONG *bp = b->d;
     OSSL_FN_ULONG *rp = r->d;
-    OSSL_FN_ULONG borrow = bn_sub_words(rp, ap, bp, (int)min);
 
     /*
-     * TODO(FIXNUM): everything following isn't strictly constant-time,
-     * and could use improvement in that regard.
+     * Three stages, after which there's a possible return.
+     * Each stage is limited by the number of remaining limbs
+     * in |r|.
+     *
+     * For each stage, |stage_limbs| is used to hold the number
+     * of limbs being treated in that stage, and |borrow| is
+     * used to transport the borrow from one stage to the other.
      */
+    size_t stage_limbs;
+    OSSL_FN_ULONG borrow;
+
+    /* Stage 1 */
+
+    stage_limbs = (min > rs) ? rs : min;
+    borrow = bn_sub_words(rp, ap, bp, (int)stage_limbs);
+    if (stage_limbs == rs)
+        return 1;
 
+    /* At this point, we know that |min| limbs have been used so far */
     ap += min;
     bp += min;
     rp += min;
 
+    /* Stage 2: borrow out is (t1 < t2) or, when equal, the borrow in */
+
     const OSSL_FN_ULONG *maxp = (a->dsize >= b->dsize) ? ap : bp;
+    const OSSL_FN_ULONG s2_mask1 = (a->dsize >= b->dsize) ? OSSL_FN_MASK : 0;
+    const OSSL_FN_ULONG s2_mask2 = ~s2_mask1;
 
-    /* "sign" borrow, depending on if maxp == ap or maxp == bp */
-    borrow *= (OSSL_FN_ULONG)((a->dsize >= b->dsize) ? 1 : -1);
+    stage_limbs = ((max > rs) ? rs : max) - min;
 
     /* calculate the result of borrowing from more significant limbs */
-    for (size_t dif = max - min; dif > 0; dif--, maxp++, rp++) {
-        OSSL_FN_ULONG t1 = *maxp;
-        OSSL_FN_ULONG t2 = (t1 - borrow) & OSSL_FN_MASK;
+    for (size_t dif = stage_limbs; dif > 0; dif--, maxp++, rp++) {
+        OSSL_FN_ULONG t1 = (*maxp & s2_mask1);
+        OSSL_FN_ULONG t2 = (*maxp & s2_mask2);
+        OSSL_FN_ULONG t3 = (t1 - t2 - borrow) & OSSL_FN_MASK;
 
-        *rp = t2;
-        borrow &= (t1 == 0);
+        *rp = t3;
+        borrow = (t1 < t2) | (borrow & (t1 == t2));
     }
+    if (max >= rs)
+        return 1;
+
+    /* Stage 3 */
+
+    /* Stage 2 returned unless |rs| > |max|, so |rs| - |max| can't wrap */
+    stage_limbs = rs - max;
 
     /* Finally, fill in the rest of the result array by borrowing from zeros */
-    for (size_t dif = rs - max; dif > 0; dif--, rp++) {
+    for (size_t dif = stage_limbs; dif > 0; dif--, rp++) {
         OSSL_FN_ULONG t1 = 0;
         OSSL_FN_ULONG t2 = (t1 - borrow) & OSSL_FN_MASK;
 
index 880d0b275c54541ac5df390de19606462120a66f..ad0dc2a0283ffce7d6e883f776f64d206242de37 100644 (file)
@@ -41,7 +41,8 @@ static int pollute(OSSL_FN *f, size_t start, size_t end)
     return 1;
 }
 
-static int check_zero(const OSSL_FN *f, size_t start, size_t end)
+static int check_limbs_value(const OSSL_FN *f, size_t start, size_t end,
+    OSSL_FN_ULONG value)
 {
     const OSSL_FN_ULONG *u = ossl_fn_get_words(f);
     size_t l = ossl_fn_get_dsize(f);
@@ -52,8 +53,10 @@ static int check_zero(const OSSL_FN *f, size_t start, size_t end)
         start = end;
 
     for (size_t i = start; i < end; i++)
-        if (u[i] != 0)
+        if (!TEST_size_t_eq(u[i], value)) {
+            TEST_note("start = %zu, end = %zu, i = %zu\n", start, end, i);
             return 0;
+        }
     return 1;
 }
 
@@ -95,6 +98,11 @@ struct test_case_st {
 
     /* Number of limbs to compare the result's OSSL_FN_ULONG array against ex */
     size_t check_size;
+
+    /* When the result is larger than check_size, the expected extended value */
+    OSSL_FN_ULONG extended_limb_value;
+#define EXTENDED_LIMB_ZERO ((OSSL_FN_ULONG)0)
+#define EXTENDED_LIMB_MINUS_ONE ((OSSL_FN_ULONG)-1)
 };
 
 static const OSSL_FN_ULONG ex_add_num0_num0[] = {
@@ -131,18 +139,64 @@ static const OSSL_FN_ULONG ex_add_num3_num3[] = {
     OSSL_FN_ULONG_C(0x1),
 };
 
-#define ADD_CASE(i, op1, op2, ex)         \
-    {                                     \
-        /* op1 */ op1,                    \
-        /* op1_size */ LIMBSOF(op1),      \
-        /* op2 */ op2,                    \
-        /* op2_size */ LIMBSOF(op2),      \
-        /* ex */ ex,                      \
-        /* ex_size */ LIMBSOF(ex),        \
-        /* op1_live_size */ LIMBSOF(op1), \
-        /* op2_live_size */ LIMBSOF(op2), \
-        /* res_live_size */ LIMBSOF(ex),  \
-        /* check_size */ LIMBSOF(ex),     \
+static int test_add_common(struct test_case_st test_case)
+{
+    const OSSL_FN_ULONG *n1 = test_case.op1;
+    size_t n1_limbs = test_case.op1_size;
+    const OSSL_FN_ULONG *n2 = test_case.op2;
+    size_t n2_limbs = test_case.op2_size;
+    const OSSL_FN_ULONG *ex = test_case.ex;
+    size_t n1_new_limbs = test_case.op1_live_size;
+    size_t n2_new_limbs = test_case.op2_live_size;
+    size_t res_limbs = test_case.res_live_size;
+    size_t check_limbs = test_case.check_size;
+    OSSL_FN_ULONG extended_value = test_case.extended_limb_value;
+    int ret = 1;
+    OSSL_FN *fn1 = NULL, *fn2 = NULL, *res = NULL;
+    const OSSL_FN_ULONG *u = NULL;
+
+    /* To test that OSSL_FN_add() does a complete job, 'res' is pre-polluted */
+
+    if (!TEST_ptr(fn1 = OSSL_FN_new_limbs(n1_new_limbs))
+        || !TEST_ptr(fn2 = OSSL_FN_new_limbs(n2_new_limbs))
+        || !TEST_true(ossl_fn_set_words(fn1, n1, n1_limbs))
+        || !TEST_true(ossl_fn_set_words(fn2, n2, n2_limbs))
+        || !TEST_ptr(res = OSSL_FN_new_limbs(res_limbs))
+        || !TEST_true(pollute(res, 0, res_limbs))) {
+        ret = 0;
+        /* Setup failed: fail the test; 'end' frees what was allocated */
+        goto end;
+    }
+
+    if (!TEST_true(OSSL_FN_add(res, fn1, fn2))
+        || !TEST_ptr(u = ossl_fn_get_words(res))
+        || !TEST_mem_eq(u, check_limbs * OSSL_FN_BYTES,
+            ex, check_limbs * OSSL_FN_BYTES)
+        || !TEST_true(check_limbs_value(res, check_limbs, res_limbs,
+            extended_value)))
+        ret = 0;
+
+end:
+    OSSL_FN_free(fn1);
+    OSSL_FN_free(fn2);
+    OSSL_FN_free(res);
+
+    return ret;
+}
+
+#define ADD_CASE(i, op1, op2, ex)                     \
+    {                                                 \
+        /* op1 */ op1,                                \
+        /* op1_size */ LIMBSOF(op1),                  \
+        /* op2 */ op2,                                \
+        /* op2_size */ LIMBSOF(op2),                  \
+        /* ex */ ex,                                  \
+        /* ex_size */ LIMBSOF(ex),                    \
+        /* op1_live_size */ LIMBSOF(op1) + 1,         \
+        /* op2_live_size */ LIMBSOF(op2) + 2,         \
+        /* res_live_size */ LIMBSOF(ex) + 3,          \
+        /* check_size */ LIMBSOF(ex),                 \
+        /* extended_limb_value */ EXTENDED_LIMB_ZERO, \
     }
 
 static struct test_case_st test_add_cases[] = {
@@ -166,35 +220,46 @@ static struct test_case_st test_add_cases[] = {
 
 static int test_add(int i)
 {
-    const OSSL_FN_ULONG *n1 = test_add_cases[i].op1;
-    size_t n1_limbs = test_add_cases[i].op1_size;
-    const OSSL_FN_ULONG *n2 = test_add_cases[i].op2;
-    size_t n2_limbs = test_add_cases[i].op2_size;
-    const OSSL_FN_ULONG *ex = test_add_cases[i].ex;
-    size_t ex_limbs = test_add_cases[i].ex_size;
-    size_t check_limbs = test_add_cases[i].check_size;
-    int ret = 1;
-    OSSL_FN *fn1 = NULL, *fn2 = NULL, *res = NULL;
-    const OSSL_FN_ULONG *u = NULL;
-
-    /* To test that OSSL_FN_add() does a complete job, 'res' is pre-polluted */
+    return test_add_common(test_add_cases[i]);
+}
 
-    if (!TEST_ptr(fn1 = OSSL_FN_new_limbs(n1_limbs))
-        || !TEST_ptr(fn2 = OSSL_FN_new_limbs(n2_limbs))
-        || !TEST_ptr(res = OSSL_FN_new_limbs(ex_limbs))
-        || !TEST_true(pollute(res, 0, ex_limbs))
-        || !TEST_true(ossl_fn_set_words(fn1, n1, n1_limbs))
-        || !TEST_true(ossl_fn_set_words(fn2, n2, n2_limbs))
-        || !TEST_true(OSSL_FN_add(res, fn1, fn2))
-        || !TEST_ptr(u = ossl_fn_get_words(res))
-        || !TEST_mem_eq(u, ossl_fn_get_dsize(res) * OSSL_FN_BYTES,
-            ex, check_limbs * OSSL_FN_BYTES))
-        ret = 0;
-    OSSL_FN_free(fn1);
-    OSSL_FN_free(fn2);
-    OSSL_FN_free(res);
+#define ADD_TRUNCATED_CASE(i, op1, op2, ex)           \
+    {                                                 \
+        /* op1 */ op1,                                \
+        /* op1_size */ LIMBSOF(op1),                  \
+        /* op2 */ op2,                                \
+        /* op2_size */ LIMBSOF(op2),                  \
+        /* ex */ ex,                                  \
+        /* ex_size */ LIMBSOF(ex),                    \
+        /* op1_live_size */ LIMBSOF(op1) + 1,         \
+        /* op2_live_size */ LIMBSOF(op2) + 2,         \
+        /* res_live_size */ LIMBSOF(ex) - 1,          \
+        /* check_size */ LIMBSOF(ex) - 1,             \
+        /* extended_limb_value */ EXTENDED_LIMB_ZERO, \
+    }
 
-    return ret;
+static struct test_case_st test_add_truncated_cases[] = {
+    ADD_TRUNCATED_CASE(1, num0, num0, ex_add_num0_num0),
+    ADD_TRUNCATED_CASE(2, num0, num1, ex_add_num0_num1),
+    ADD_TRUNCATED_CASE(3, num0, num2, ex_add_num0_num2),
+    ADD_TRUNCATED_CASE(4, num0, num3, ex_add_num0_num3),
+    ADD_TRUNCATED_CASE(5, num1, num0, ex_add_num0_num1), /* Commutativity check */
+    ADD_TRUNCATED_CASE(6, num1, num1, ex_add_num1_num1),
+    ADD_TRUNCATED_CASE(7, num1, num2, ex_add_num1_num2),
+    ADD_TRUNCATED_CASE(8, num1, num3, ex_add_num1_num3),
+    ADD_TRUNCATED_CASE(9, num2, num0, ex_add_num0_num2), /* Commutativity check */
+    ADD_TRUNCATED_CASE(10, num2, num1, ex_add_num1_num2), /* Commutativity check */
+    ADD_TRUNCATED_CASE(11, num2, num2, ex_add_num2_num2),
+    ADD_TRUNCATED_CASE(12, num2, num3, ex_add_num2_num3),
+    ADD_TRUNCATED_CASE(13, num3, num0, ex_add_num0_num3), /* Commutativity check */
+    ADD_TRUNCATED_CASE(14, num3, num1, ex_add_num1_num3), /* Commutativity check */
+    ADD_TRUNCATED_CASE(15, num3, num2, ex_add_num2_num3), /* Commutativity check */
+    ADD_TRUNCATED_CASE(16, num3, num3, ex_add_num3_num3),
+};
+
+static int test_add_truncated(int i)
+{
+    return test_add_common(test_add_truncated_cases[i]);
 }
 
 static const OSSL_FN_ULONG ex_sub_num0_num0[] = {
@@ -252,65 +317,44 @@ static const OSSL_FN_ULONG ex_sub_num3_num3[] = {
     OSSL_FN_ULONG64_C(0x00000000, 0x00000000),
 };
 
-#define SUB_CASE(i, op1, op2, ex)         \
-    {                                     \
-        /* op1, with size */ op1,         \
-        LIMBSOF(op1),                     \
-        /* op2, with size */ op2,         \
-        LIMBSOF(op2),                     \
-        /* ex, with size */ ex,           \
-        LIMBSOF(ex),                      \
-        /* op1_live_size */ LIMBSOF(op1), \
-        /* op2_live_size */ LIMBSOF(op2), \
-        /* res_live_size */ LIMBSOF(ex),  \
-        /* check_size */ LIMBSOF(ex),     \
-    }
-
-static struct test_case_st test_sub_cases[] = {
-    SUB_CASE(1, num0, num0, ex_sub_num0_num0),
-    SUB_CASE(2, num0, num1, ex_sub_num0_num1),
-    SUB_CASE(3, num0, num2, ex_sub_num0_num2),
-    SUB_CASE(4, num0, num3, ex_sub_num0_num3),
-    SUB_CASE(5, num1, num0, ex_sub_num1_num0),
-    SUB_CASE(6, num1, num1, ex_sub_num1_num1),
-    SUB_CASE(7, num1, num2, ex_sub_num1_num2),
-    SUB_CASE(8, num1, num3, ex_sub_num1_num3),
-    SUB_CASE(9, num2, num0, ex_sub_num2_num0),
-    SUB_CASE(10, num2, num1, ex_sub_num2_num1),
-    SUB_CASE(11, num2, num2, ex_sub_num2_num2),
-    SUB_CASE(12, num2, num3, ex_sub_num2_num3),
-    SUB_CASE(13, num3, num0, ex_sub_num3_num0),
-    SUB_CASE(14, num3, num1, ex_sub_num3_num1),
-    SUB_CASE(15, num3, num2, ex_sub_num3_num2),
-    SUB_CASE(16, num3, num3, ex_sub_num3_num3),
-};
-
-static int test_sub(int i)
+static int test_sub_common(struct test_case_st test_case)
 {
-    const OSSL_FN_ULONG *n1 = test_sub_cases[i].op1;
-    size_t n1_limbs = test_sub_cases[i].op1_size;
-    const OSSL_FN_ULONG *n2 = test_sub_cases[i].op2;
-    size_t n2_limbs = test_sub_cases[i].op2_size;
-    const OSSL_FN_ULONG *ex = test_sub_cases[i].ex;
-    size_t ex_limbs = test_sub_cases[i].ex_size;
-    size_t check_limbs = test_sub_cases[i].check_size;
+    const OSSL_FN_ULONG *n1 = test_case.op1;
+    size_t n1_limbs = test_case.op1_size;
+    const OSSL_FN_ULONG *n2 = test_case.op2;
+    size_t n2_limbs = test_case.op2_size;
+    const OSSL_FN_ULONG *ex = test_case.ex;
+    size_t n1_new_limbs = test_case.op1_live_size;
+    size_t n2_new_limbs = test_case.op2_live_size;
+    size_t res_limbs = test_case.res_live_size;
+    size_t check_limbs = test_case.check_size;
+    OSSL_FN_ULONG extended_value = test_case.extended_limb_value;
     int ret = 1;
     OSSL_FN *fn1 = NULL, *fn2 = NULL, *res = NULL;
     const OSSL_FN_ULONG *u = NULL;
 
     /* To test that OSSL_FN_sub() does a complete job, 'res' is pre-polluted */
 
-    if (!TEST_ptr(fn1 = OSSL_FN_new_limbs(n1_limbs))
-        || !TEST_ptr(fn2 = OSSL_FN_new_limbs(n2_limbs))
-        || !TEST_ptr(res = OSSL_FN_new_limbs(ex_limbs))
-        || !TEST_true(pollute(res, 0, ex_limbs))
+    if (!TEST_ptr(fn1 = OSSL_FN_new_limbs(n1_new_limbs))
+        || !TEST_ptr(fn2 = OSSL_FN_new_limbs(n2_new_limbs))
         || !TEST_true(ossl_fn_set_words(fn1, n1, n1_limbs))
         || !TEST_true(ossl_fn_set_words(fn2, n2, n2_limbs))
-        || !TEST_true(OSSL_FN_sub(res, fn1, fn2))
+        || !TEST_ptr(res = OSSL_FN_new_limbs(res_limbs))
+        || !TEST_true(pollute(res, 0, res_limbs))) {
+        ret = 0;
+        /* Setup failed: fail the test; 'end' frees what was allocated */
+        goto end;
+    }
+
+    if (!TEST_true(OSSL_FN_sub(res, fn1, fn2))
         || !TEST_ptr(u = ossl_fn_get_words(res))
-        || !TEST_mem_eq(u, ossl_fn_get_dsize(res) * OSSL_FN_BYTES,
-            ex, check_limbs * OSSL_FN_BYTES))
+        || !TEST_mem_eq(u, check_limbs * OSSL_FN_BYTES,
+            ex, check_limbs * OSSL_FN_BYTES)
+        || !TEST_true(check_limbs_value(res, check_limbs, res_limbs,
+            extended_value)))
         ret = 0;
+
+end:
     OSSL_FN_free(fn1);
     OSSL_FN_free(fn2);
     OSSL_FN_free(res);
@@ -318,6 +362,84 @@ static int test_sub(int i)
     return ret;
 }
 
+#define SUB_CASE(i, op1, op2, ex, ext)        \
+    {                                         \
+        /* op1 */ op1,                        \
+        /* op1_size */ LIMBSOF(op1),          \
+        /* op2 */ op2,                        \
+        /* op2_size */ LIMBSOF(op2),          \
+        /* ex */ ex,                          \
+        /* ex_size */ LIMBSOF(ex),            \
+        /* op1_live_size */ LIMBSOF(op1) + 1, \
+        /* op2_live_size */ LIMBSOF(op2) + 2, \
+        /* res_live_size */ LIMBSOF(ex) + 3,  \
+        /* check_size */ LIMBSOF(ex),         \
+        /* extended_limb_value */ (ext),      \
+    }
+
+static struct test_case_st test_sub_cases[] = {
+    SUB_CASE(1, num0, num0, ex_sub_num0_num0, EXTENDED_LIMB_ZERO),
+    SUB_CASE(2, num0, num1, ex_sub_num0_num1, EXTENDED_LIMB_ZERO),
+    SUB_CASE(3, num0, num2, ex_sub_num0_num2, EXTENDED_LIMB_ZERO),
+    SUB_CASE(4, num0, num3, ex_sub_num0_num3, EXTENDED_LIMB_MINUS_ONE),
+    SUB_CASE(5, num1, num0, ex_sub_num1_num0, EXTENDED_LIMB_MINUS_ONE),
+    SUB_CASE(6, num1, num1, ex_sub_num1_num1, EXTENDED_LIMB_ZERO),
+    SUB_CASE(7, num1, num2, ex_sub_num1_num2, EXTENDED_LIMB_MINUS_ONE),
+    SUB_CASE(8, num1, num3, ex_sub_num1_num3, EXTENDED_LIMB_MINUS_ONE),
+    SUB_CASE(9, num2, num0, ex_sub_num2_num0, EXTENDED_LIMB_MINUS_ONE),
+    SUB_CASE(10, num2, num1, ex_sub_num2_num1, EXTENDED_LIMB_ZERO),
+    SUB_CASE(11, num2, num2, ex_sub_num2_num2, EXTENDED_LIMB_ZERO),
+    SUB_CASE(12, num2, num3, ex_sub_num2_num3, EXTENDED_LIMB_MINUS_ONE),
+    SUB_CASE(13, num3, num0, ex_sub_num3_num0, EXTENDED_LIMB_ZERO),
+    SUB_CASE(14, num3, num1, ex_sub_num3_num1, EXTENDED_LIMB_ZERO),
+    SUB_CASE(15, num3, num2, ex_sub_num3_num2, EXTENDED_LIMB_ZERO),
+    SUB_CASE(16, num3, num3, ex_sub_num3_num3, EXTENDED_LIMB_ZERO),
+};
+
+static int test_sub(int i)
+{
+    return test_sub_common(test_sub_cases[i]);
+}
+
+#define SUB_TRUNCATED_CASE(i, op1, op2, ex)           \
+    {                                                 \
+        /* op1 */ op1,                                \
+        /* op1_size */ LIMBSOF(op1),                  \
+        /* op2 */ op2,                                \
+        /* op2_size */ LIMBSOF(op2),                  \
+        /* ex */ ex,                                  \
+        /* ex_size */ LIMBSOF(ex),                    \
+        /* op1_live_size */ LIMBSOF(op1) + 1,         \
+        /* op2_live_size */ LIMBSOF(op2) + 2,         \
+        /* res_live_size */ LIMBSOF(ex) - 1,          \
+        /* check_size */ LIMBSOF(ex) - 1,             \
+        /* extended_limb_value */ EXTENDED_LIMB_ZERO, \
+    }
+
+static struct test_case_st test_sub_truncated_cases[] = {
+    SUB_TRUNCATED_CASE(1, num0, num0, ex_sub_num0_num0),
+    SUB_TRUNCATED_CASE(2, num0, num1, ex_sub_num0_num1),
+    SUB_TRUNCATED_CASE(3, num0, num2, ex_sub_num0_num2),
+    SUB_TRUNCATED_CASE(4, num0, num3, ex_sub_num0_num3),
+    SUB_TRUNCATED_CASE(5, num1, num0, ex_sub_num1_num0),
+    SUB_TRUNCATED_CASE(6, num1, num1, ex_sub_num1_num1),
+    SUB_TRUNCATED_CASE(7, num1, num2, ex_sub_num1_num2),
+    SUB_TRUNCATED_CASE(8, num1, num3, ex_sub_num1_num3),
+    SUB_TRUNCATED_CASE(9, num2, num0, ex_sub_num2_num0),
+    SUB_TRUNCATED_CASE(10, num2, num1, ex_sub_num2_num1),
+    SUB_TRUNCATED_CASE(11, num2, num2, ex_sub_num2_num2),
+    SUB_TRUNCATED_CASE(12, num2, num3, ex_sub_num2_num3),
+    SUB_TRUNCATED_CASE(13, num3, num0, ex_sub_num3_num0),
+    SUB_TRUNCATED_CASE(14, num3, num1, ex_sub_num3_num1),
+    SUB_TRUNCATED_CASE(15, num3, num2, ex_sub_num3_num2),
+    SUB_TRUNCATED_CASE(16, num3, num3, ex_sub_num3_num3),
+};
+
+static int test_sub_truncated(int i)
+{
+    return test_sub_common(test_sub_truncated_cases[i]);
+}
+
 /* A set of expected results, also in OSSL_FN_ULONG array form */
 static const OSSL_FN_ULONG ex_mul_num0_num0[] = {
     OSSL_FN_ULONG64_C(0x00000000, 0x00000001),
@@ -493,6 +615,7 @@ static int test_mul_common(struct test_case_st test_case)
     size_t n2_new_limbs = test_case.op2_live_size;
     size_t res_limbs = test_case.res_live_size;
     size_t check_limbs = test_case.check_size;
+    OSSL_FN_ULONG extended_value = test_case.extended_limb_value;
     OSSL_FN *fn1 = NULL, *fn2 = NULL, *res = NULL;
     const OSSL_FN_ULONG *u = NULL;
 
@@ -515,7 +638,8 @@ static int test_mul_common(struct test_case_st test_case)
         || !TEST_ptr(u = ossl_fn_get_words(res))
         || !TEST_mem_eq(u, check_limbs * OSSL_FN_BYTES,
             ex, check_limbs * OSSL_FN_BYTES)
-        || !TEST_true(check_zero(res, check_limbs, res_limbs)))
+        || !TEST_true(check_limbs_value(res, check_limbs, res_limbs,
+            extended_value)))
         ret = 0;
 
 end:
@@ -540,6 +664,7 @@ end:
         /* op2_live_size */ LIMBSOF(op2) + 2,                            \
         /* res_live_size */ LIMBSOF(op1) + LIMBSOF(op2) + ((i - 1) % 4), \
         /* check_size */ LIMBSOF(ex),                                    \
+        /* extended_limb_value */ EXTENDED_LIMB_ZERO,                    \
     }
 
 static struct test_case_st test_mul_cases[] = {
@@ -572,18 +697,19 @@ static int test_mul(int i)
 }
 
 /* i should be set to match the iteration number that's displayed when testing */
-#define MUL_TRUNCATED_CASE(i, op1, op2, ex)   \
-    {                                         \
-        /* op1 */ op1,                        \
-        /* op1_size */ LIMBSOF(op1),          \
-        /* op2 */ op2,                        \
-        /* op2_size */ LIMBSOF(op2),          \
-        /* ex */ ex,                          \
-        /* ex_size */ LIMBSOF(ex),            \
-        /* op1_live_size */ LIMBSOF(op1) + 1, \
-        /* op2_live_size */ LIMBSOF(op2) + 2, \
-        /* res_live_size */ LIMBSOF(ex) / 2,  \
-        /* check_size */ LIMBSOF(ex) / 2,     \
+#define MUL_TRUNCATED_CASE(i, op1, op2, ex)           \
+    {                                                 \
+        /* op1 */ op1,                                \
+        /* op1_size */ LIMBSOF(op1),                  \
+        /* op2 */ op2,                                \
+        /* op2_size */ LIMBSOF(op2),                  \
+        /* ex */ ex,                                  \
+        /* ex_size */ LIMBSOF(ex),                    \
+        /* op1_live_size */ LIMBSOF(op1) + 1,         \
+        /* op2_live_size */ LIMBSOF(op2) + 2,         \
+        /* res_live_size */ LIMBSOF(ex) / 2,          \
+        /* check_size */ LIMBSOF(ex) / 2,             \
+        /* extended_limb_value */ EXTENDED_LIMB_ZERO, \
     }
 /* A special case, where the truncation is set to the size of ex minus 64 bits */
 #define MUL_TRUNCATED_SPECIAL_CASE1(i, op1, op2, ex)         \
@@ -598,6 +724,7 @@ static int test_mul(int i)
         /* op2_live_size */ LIMBSOF(op2) + 2,                \
         /* res_live_size */ LIMBSOF(ex) - 8 / OSSL_FN_BYTES, \
         /* check_size */ LIMBSOF(ex) - 8 / OSSL_FN_BYTES,    \
+        /* extended_limb_value */ EXTENDED_LIMB_ZERO,        \
     }
 
 static struct test_case_st test_mul_truncate_cases[] = {
@@ -632,7 +759,9 @@ static int test_mul_truncated(int i)
 int setup_tests(void)
 {
     ADD_ALL_TESTS(test_add, 16);
+    ADD_ALL_TESTS(test_add_truncated, 16);
     ADD_ALL_TESTS(test_sub, 16);
+    ADD_ALL_TESTS(test_sub_truncated, 16);
     ADD_ALL_TESTS(test_mul_feature_r_is_operand, 4);
     ADD_ALL_TESTS(test_mul, OSSL_NELEM(test_mul_cases));
     ADD_ALL_TESTS(test_mul_truncated, OSSL_NELEM(test_mul_truncate_cases));