const uint8_t *in, size_t in_len,
uint32_t gamma1);
-/*
- * @brief Reduces x mod q in constant time
+/*-
+ * @brief Reduces 0 <= x < 2*q, mod q.
* i.e. return x < q ? x : x - q;
*
- * @param x Where x is assumed to be in the range 0 <= x < 2*q
+ * Subtract |q| if the input is larger, without exposing a side-channel,
+ * avoiding the "clangover" attack. See |constish_time_true| for a discussion
+ * on why the value barrier is by default omitted.
+ *
* @returns the difference in the range 0..q-1
*/
-static ossl_inline ossl_unused uint32_t reduce_once(uint32_t x)
+static ossl_inline ossl_unused __owur uint32_t reduce_once(uint32_t x)
{
- return constant_time_select_32(constant_time_lt_32(x, ML_DSA_Q), x, x - ML_DSA_Q);
+ const uint32_t subtracted = x - ML_DSA_Q;
+ uint32_t mask = constish_time_true(subtracted >> 31);
+
+ return (mask & x) | (~mask & subtracted);
}
/*
- * @brief Calculate The positive value of (a-b) mod q in constant time.
+ * @brief Calculates the positive value of (a-b) mod q in constant time.
*
* a - b mod q gives a value in the range -(q-1)..(q-1)
* By adding q we get a range of 1..(2q-1).
/*
* @brief Returns the absolute value in constant time.
- * i.e. return is_positive(x) ? x : -x;
+ * i.e. return is_negative(x) ? -x : x;
*/
static ossl_inline ossl_unused uint32_t abs_signed(uint32_t x)
{
- return constant_time_select_32(constant_time_lt_32(x, 0x80000000), x, 0u - x);
+ uint32_t mask = 0u - (x >> 31);
+
+ return constant_time_select_32(mask, 0u - x, x);
}
/*
* @brief Returns the absolute value modulo q in constant time
- * i.e return x > (q - 1) / 2 ? q - x : x;
+ * i.e return x <= (q-1)/2 ? x : q - x;
*/
static ossl_inline ossl_unused uint32_t abs_mod_prime(uint32_t x)
{
- return constant_time_select_32(constant_time_lt_32(ML_DSA_Q_MINUS1_DIV2, x),
- ML_DSA_Q - x, x);
+ uint32_t mask = x - ML_DSA_Q_MINUS1_DIV2;
+
+ mask = 0u - (mask >> 31);
+ return constant_time_select_32(mask, x, ML_DSA_Q - x);
}
/*
*/
static ossl_inline ossl_unused uint32_t maximum(uint32_t x, uint32_t y)
{
- return constant_time_select_int(constant_time_lt(x, y), y, x);
+ uint32_t mask = x - y;
+ mask = 0u - (mask >> 31);
+ return constant_time_select_int(mask, y, x);
}
#endif /* OSSL_CRYPTO_ML_DSA_LOCAL_H */