return result;
}
-static uint128_t add_128_128(uint128_t a, uint128_t b)
+/* Calculate addition with overflow checking. Returns true on wrap-around,
+ * false otherwise.
+ */
+static bool check_add_128_128_overflow(uint128_t *result, uint128_t a,
+ uint128_t b)
{
- uint128_t result;
+ bool carry;
- result.m_low = a.m_low + b.m_low;
- result.m_high = a.m_high + b.m_high + (result.m_low < a.m_low);
+ result->m_low = a.m_low + b.m_low;
+ carry = (result->m_low < a.m_low);
- return result;
+ result->m_high = a.m_high + b.m_high + carry;
+
+ /* Using constant-time bitwise arithmetic to prevent timing
+ * side-channels.
+ */
+ carry = (result->m_high < a.m_high) |
+ ((result->m_high == a.m_high) & carry);
+
+ return carry;
}
static void vli_mult(u64 *result, const u64 *left, const u64 *right,
uint128_t product;
product = mul_64_64(left[i], right[k - i]);
-
- r01 = add_128_128(r01, product);
- r2 += (r01.m_high < product.m_high);
+ r2 += check_add_128_128_overflow(&r01, r01, product);
}
result[k] = r01.m_low;
uint128_t product;
product = mul_64_64(left[k], right);
- r01 = add_128_128(r01, product);
+ check_add_128_128_overflow(&r01, r01, product);
/* no carry */
result[k] = r01.m_low;
r01.m_low = r01.m_high;
product.m_low <<= 1;
}
- r01 = add_128_128(r01, product);
- r2 += (r01.m_high < product.m_high);
+ r2 += check_add_128_128_overflow(&r01, r01, product);
}
result[k] = r01.m_low;