]> git.ipfire.org Git - thirdparty/openssl.git/commitdiff
Add a divide rounding up safe math function.
authorPauli <pauli@openssl.org>
Tue, 15 Mar 2022 03:19:07 +0000 (14:19 +1100)
committerPauli <pauli@openssl.org>
Tue, 29 Mar 2022 23:10:25 +0000 (10:10 +1100)
This function takes arguments a & b and computes a / b rounding any
remainder up.

It is safe with respect to overflow and negative inputs.  It's only fast for
non-negative inputs.

Reviewed-by: Matt Caswell <matt@openssl.org>
Reviewed-by: Tomas Mraz <tomas@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/17884)

include/internal/safe_math.h
test/safe_math_test.c

index 85c6147e55c81aa882592232452d58a66f0c7d17..d14a090b242af8cc8dc9ec0850f004276eb1f3d2 100644 (file)
         return safe_add_ ## type_name(y, x / c, err);                        \
     }
 
+/*
+ * Calculate a / b rounding up:
+ *     i.e. a / b + (a % b != 0)
+ * Which is usually (less safely) converted to (a + b - 1) / b
+ * If you *know* that b != 0, then it's safe to ignore err.
+ */
+#define OSSL_SAFE_MATH_DIV_ROUND_UP(type_name, type, max) \
+    static ossl_inline ossl_unused type safe_div_round_up_ ## type_name      \
+        (type a, type b, int *errp)                                          \
+    {                                                                        \
+        type x;                                                              \
+        int *err, err_local = 0;                                             \
+                                                                             \
+        /* Allow errors to be ignored by callers */                          \
+        err = errp != NULL ? errp : &err_local;                              \
+        /* Fast path, both positive */                                       \
+        if (b > 0 && a > 0) {                                                \
+            /* Faster path: no overflow concerns */                          \
+            if (a < max - b)                                                 \
+                return (a + b - 1) / b;                                      \
+            return a / b + (a % b != 0);                                     \
+        }                                                                    \
+        if (b == 0) {                                                        \
+            *err |= 1;                                                       \
+            return a == 0 ? 0 : max;                                         \
+        }                                                                    \
+        if (a == 0)                                                          \
+            return 0;                                                        \
+        /* Rather slow path because there are negatives involved */          \
+        x = safe_mod_ ## type_name(a, b, err);                               \
+        return safe_add_ ## type_name(safe_div_ ## type_name(a, b, err),     \
+                                      x != 0, err);                          \
+    }
+
 /* Calculate ranges of types */
 # define OSSL_SAFE_MATH_MINS(type) ((type)1 << (sizeof(type) * 8 - 1))
 # define OSSL_SAFE_MATH_MAXS(type) (~OSSL_SAFE_MATH_MINS(type))
                         OSSL_SAFE_MATH_MAXS(type))                      \
     OSSL_SAFE_MATH_MODS(type_name, type, OSSL_SAFE_MATH_MINS(type),     \
                         OSSL_SAFE_MATH_MAXS(type))                      \
+    OSSL_SAFE_MATH_DIV_ROUND_UP(type_name, type,                        \
+                                OSSL_SAFE_MATH_MAXS(type))              \
     OSSL_SAFE_MATH_MULDIVS(type_name, type, OSSL_SAFE_MATH_MAXS(type))  \
     OSSL_SAFE_MATH_NEGS(type_name, type, OSSL_SAFE_MATH_MINS(type))     \
     OSSL_SAFE_MATH_ABSS(type_name, type, OSSL_SAFE_MATH_MINS(type))
     OSSL_SAFE_MATH_MULU(type_name, type, OSSL_SAFE_MATH_MAXU(type))     \
     OSSL_SAFE_MATH_DIVU(type_name, type, OSSL_SAFE_MATH_MAXU(type))     \
     OSSL_SAFE_MATH_MODU(type_name, type)                                \
+    OSSL_SAFE_MATH_DIV_ROUND_UP(type_name, type,                        \
+                                OSSL_SAFE_MATH_MAXU(type))              \
     OSSL_SAFE_MATH_MULDIVU(type_name, type, OSSL_SAFE_MATH_MAXU(type))  \
     OSSL_SAFE_MATH_NEGU(type_name, type)                                \
     OSSL_SAFE_MATH_ABSU(type_name, type)
index da50ec816bc087ad5265d05917e28003b9045326..ae397151cdb02725e0d3aeab03e98bc4f7030ed1 100644 (file)
@@ -27,28 +27,34 @@ OSSL_SAFE_MATH_UNSIGNED(size_t, size_t)
 
 static const struct {
     int a, b;
-    int sum_err, sub_err, mul_err, div_err, mod_err, neg_a_err, neg_b_err;
-    int abs_a_err, abs_b_err;
-} test_ints[] = {
-    { 1, 3,                 0, 0, 0, 0, 0, 0, 0, 0, 0 },
-    { -1, 3,                0, 0, 0, 0, 0, 0, 0, 0, 0 },
-    { 1, -3,                0, 0, 0, 0, 0, 0, 0, 0, 0 },
-    { -1, -3,               0, 0, 0, 0, 0, 0, 0, 0, 0 },
-    { INT_MAX, 1,           1, 0, 0, 0, 0, 0, 0, 0, 0 },
-    { INT_MAX, 2,           1, 0, 1, 0, 0, 0, 0, 0, 0 },
-    { INT_MIN, 1,           0, 1, 0, 0, 0, 1, 0, 1, 0 },
-    { 1, INT_MIN,           0, 1, 0, 0, 0, 0, 1, 0, 1 },
-    { INT_MIN, 2,           0, 1, 1, 0, 0, 1, 0, 1, 0 },
-    { 2, INT_MIN,           0, 1, 1, 0, 0, 0, 1, 0, 1 },
-    { INT_MIN, -1,          1, 0, 1, 1, 1, 1, 0, 1, 0 },
-    { INT_MAX, INT_MIN,     0, 1, 1, 0, 0, 0, 1, 0, 1 },
-    { INT_MIN, INT_MAX,     0, 1, 1, 0, 0, 1, 0, 1, 0 },
-    { 3, 0,                 0, 0, 0, 1, 1, 0, 0, 0, 0 },
+    int sum_err, sub_err, mul_err, div_err, mod_err, div_round_up_err;
+    int neg_a_err, neg_b_err, abs_a_err, abs_b_err;
+} test_ints[] = {       /*  +  -  *  /  %  /r -a -b |a||b|  */
+    { 1, 3,                 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+    { -1, 3,                0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+    { 1, -3,                0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+    { -1, -3,               0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+    { 3, 2,                 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+    { -3, 2,                0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+    { 2, -3,                0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+    { -2, -3,               0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+    { INT_MAX, 1,           1, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
+    { INT_MAX, 2,           1, 0, 1, 0, 0, 0, 0, 0, 0, 0 },
+    { INT_MAX, 4,           1, 0, 1, 0, 0, 0, 0, 0, 0, 0 },
+    { INT_MAX - 3 , 4,      1, 0, 1, 0, 0, 0, 0, 0, 0, 0 },
+    { INT_MIN, 1,           0, 1, 0, 0, 0, 0, 1, 0, 1, 0 },
+    { 1, INT_MIN,           0, 1, 0, 0, 0, 0, 0, 1, 0, 1 },
+    { INT_MIN, 2,           0, 1, 1, 0, 0, 0, 1, 0, 1, 0 },
+    { 2, INT_MIN,           0, 1, 1, 0, 0, 0, 0, 1, 0, 1 },
+    { INT_MIN, -1,          1, 0, 1, 1, 1, 1, 1, 0, 1, 0 },
+    { INT_MAX, INT_MIN,     0, 1, 1, 0, 0, 0, 0, 1, 0, 1 },
+    { INT_MIN, INT_MAX,     0, 1, 1, 0, 0, 0, 1, 0, 1, 0 },
+    { 3, 0,                 0, 0, 0, 1, 1, 1, 0, 0, 0, 0 },
 };
 
 static int test_int_ops(int n)
 {
-    int err, r;
+    int err, r, s;
     const int a = test_ints[n].a, b = test_ints[n].b;
 
     err = 0;
@@ -81,6 +87,15 @@ static int test_int_ops(int n)
             || (!err && !TEST_int_eq(r, a % b)))
         goto err;
 
+    err = 0;
+    r = safe_div_round_up_int(a, b, &err);
+    if (!TEST_int_eq(err, test_ints[n].div_round_up_err))
+        goto err;
+    s = safe_mod_int(a, b, &err);
+    s = safe_add_int(safe_div_int(a, b, &err), s != 0, &err);
+    if (!err && !TEST_int_eq(r, s))
+        goto err;
+
     err = 0;
     r = safe_neg_int(a, &err);
     if (!TEST_int_eq(err, test_ints[n].neg_a_err)
@@ -112,15 +127,17 @@ static int test_int_ops(int n)
 
 static const struct {
     unsigned int a, b;
-    int sum_err, sub_err, mul_err, div_err, mod_err;
-} test_uints[] = {
-    { 3, 1,                 0, 0, 0, 0, 0 },
-    { 1, 3,                 0, 1, 0, 0, 0 },
-    { UINT_MAX, 1,          1, 0, 0, 0, 0 },
-    { UINT_MAX, 2,          1, 0, 1, 0, 0 },
-    { 1, UINT_MAX,          1, 1, 0, 0, 0 },
-    { 2, UINT_MAX,          1, 1, 1, 0, 0 },
-    { UINT_MAX, 0,          0, 0, 0, 1, 1 },
+    int sum_err, sub_err, mul_err, div_err, mod_err, div_round_up_err;
+} test_uints[] = {      /*  +  -  *  /  %  /r   */
+    { 3, 1,                 0, 0, 0, 0, 0, 0 },
+    { 1, 3,                 0, 1, 0, 0, 0, 0 },
+    { UINT_MAX, 1,          1, 0, 0, 0, 0, 0 },
+    { UINT_MAX, 2,          1, 0, 1, 0, 0, 0 },
+    { UINT_MAX, 16,         1, 0, 1, 0, 0, 0 },
+    { UINT_MAX - 13, 16,    1, 0, 1, 0, 0, 0 },
+    { 1, UINT_MAX,          1, 1, 0, 0, 0, 0 },
+    { 2, UINT_MAX,          1, 1, 1, 0, 0, 0 },
+    { UINT_MAX, 0,          0, 0, 0, 1, 1, 1 },
 };
 
 static int test_uint_ops(int n)
@@ -159,6 +176,12 @@ static int test_uint_ops(int n)
             || (!err && !TEST_uint_eq(r, a % b)))
         goto err;
 
+    err = 0;
+    r = safe_div_round_up_uint(a, b, &err);
+    if (!TEST_int_eq(err, test_uints[n].div_round_up_err)
+            || (!err && !TEST_uint_eq(r, a / b + (a % b != 0))))
+        goto err;
+
     err = 0;
     r = safe_neg_uint(a, &err);
     if (!TEST_int_eq(err, a != 0) || (!err && !TEST_uint_eq(r, 0)))
@@ -186,15 +209,18 @@ static int test_uint_ops(int n)
 
 static const struct {
     size_t a, b;
-    int sum_err, sub_err, mul_err, div_err, mod_err;
+    int sum_err, sub_err, mul_err, div_err, mod_err, div_round_up_err;
 } test_size_ts[] = {
-    { 3, 1,                 0, 0, 0, 0, 0 },
-    { 1, 3,                 0, 1, 0, 0, 0 },
-    { SIZE_MAX, 1,          1, 0, 0, 0, 0 },
-    { SIZE_MAX, 2,          1, 0, 1, 0, 0 },
-    { 1, SIZE_MAX,          1, 1, 0, 0, 0 },
-    { 2, SIZE_MAX,          1, 1, 1, 0, 0 },
-    { 11, 0,                0, 0, 0, 1, 1 },
+    { 3, 1,                 0, 0, 0, 0, 0, 0 },
+    { 1, 3,                 0, 1, 0, 0, 0, 0 },
+    { 36, 8,                0, 0, 0, 0, 0, 0 },
+    { SIZE_MAX, 1,          1, 0, 0, 0, 0, 0 },
+    { SIZE_MAX, 2,          1, 0, 1, 0, 0, 0 },
+    { SIZE_MAX, 8,          1, 0, 1, 0, 0, 0 },
+    { SIZE_MAX - 3, 8,      1, 0, 1, 0, 0, 0 },
+    { 1, SIZE_MAX,          1, 1, 0, 0, 0, 0 },
+    { 2, SIZE_MAX,          1, 1, 1, 0, 0, 0 },
+    { 11, 0,                0, 0, 0, 1, 1, 1 },
 };
 
 static int test_size_t_ops(int n)
@@ -223,7 +249,7 @@ static int test_size_t_ops(int n)
 
     err = 0;
     r = safe_div_size_t(a, b, &err);
-    if (!TEST_int_eq(err, test_uints[n].div_err)
+    if (!TEST_int_eq(err, test_size_ts[n].div_err)
             || (!err && !TEST_size_t_eq(r, a / b)))
         goto err;
 
@@ -233,6 +259,12 @@ static int test_size_t_ops(int n)
             || (!err && !TEST_size_t_eq(r, a % b)))
         goto err;
 
+    err = 0;
+    r = safe_div_round_up_size_t(a, b, &err);
+    if (!TEST_int_eq(err, test_size_ts[n].div_round_up_err)
+            || (!err && !TEST_size_t_eq(r, a / b + (a % b != 0))))
+        goto err;
+
     err = 0;
     r = safe_neg_size_t(a, &err);
     if (!TEST_int_eq(err, a != 0) || (!err && !TEST_size_t_eq(r, 0)))