* https://www.openssl.org/source/license.html
*/
+#include <openssl/byteorder.h>
#include <assert.h>
#include "ml_dsa_local.h"
#include "ml_dsa_key.h"
#include "ml_dsa_sign.h"
#include "internal/packet.h"
+/* Cast mod_sub result in support of left-shifts that create 64-bit values. */
+#define mod_sub_64(a, b) ((uint64_t) mod_sub(a, b))
+
typedef int (ENCODE_FN)(const POLY *s, WPACKET *pkt);
typedef int (DECODE_FN)(POLY *s, PACKET *pkt);
if (!WPACKET_allocate_bytes(pkt, 32 * 4, &out))
return 0;
- while (in < end) {
+ do {
uint32_t z0 = *in++;
uint32_t z1 = *in++;
*out++ = z0 | (z1 << 4);
- }
+ } while (in < end);
return 1;
}
uint8_t *out;
const uint32_t *in = p->coeff, *end = in + ML_DSA_NUM_POLY_COEFFICIENTS;
- if (!WPACKET_allocate_bytes(pkt, 32 * 3, &out))
+ if (!WPACKET_allocate_bytes(pkt, 32 * 6, &out))
return 0;
- while (in < end) {
+ do {
uint32_t c0 = *in++;
uint32_t c1 = *in++;
uint32_t c2 = *in++;
uint32_t c3 = *in++;
*out++ = c0 | (c1 << 6);
- *out++ = c1 >> 4 | (c2 << 4);
- *out++ = c3;
- }
+ *out++ = (c1 >> 2) | (c2 << 4);
+ *out++ = (c2 >> 4) | (c3 << 2);
+ } while (in < end);
return 1;
}
if (!WPACKET_allocate_bytes(pkt, 32 * 10, &out))
return 0;
- while (in < end) {
+ do {
uint32_t c0 = *in++;
uint32_t c1 = *in++;
uint32_t c2 = *in++;
*out++ = (uint8_t)((c1 >> 6) | (c2 << 4));
*out++ = (uint8_t)((c2 >> 4) | (c3 << 6));
*out++ = (uint8_t)(c3 >> 2);
- }
+ } while (in < end);
return 1;
}
*/
static int poly_decode_10_bits(POLY *p, PACKET *pkt)
{
- int ret = 0;
const uint8_t *in = NULL;
- uint32_t v, mask = 0x3ff; /* 10 bits */
+ uint32_t v, w, mask = 0x3ff; /* mask selects 10 bits; v holds bytes 0..3 of each 5-byte group, w holds byte 4 */
uint32_t *out = p->coeff, *end = out + ML_DSA_NUM_POLY_COEFFICIENTS;
do {
if (!PACKET_get_bytes(pkt, &in, 5))
- goto err;
- /* put first 4 bytes into v, 5th byte is accessed directly as in[4] */
- memcpy(&v, in, 4);
+ return 0; /* could not obtain 5 bytes from pkt; nothing to clean up, so fail directly */
+
+ in = OSSL_CRYPTO_load_u32_le(&v, in); /* replaces the host-order memcpy; result is assigned back to in so the next read is the 5th byte */
+ w = *in; /* 5th byte of the group (the old code read it as in[4]) */
+
*out++ = v & mask;
*out++ = (v >> 10) & mask;
*out++ = (v >> 20) & mask;
- *out++ = (v >> 30) | (((uint32_t)in[4]) << 2);
+ *out++ = (v >> 30) | (w << 2); /* top 2 bits of v plus the 8 bits of w form the 4th 10-bit value */
} while (out < end);
- ret = 1;
-err:
- return ret;
+ return 1; /* loop ended with out == end: all ML_DSA_NUM_POLY_COEFFICIENTS coefficients written */
}
/*
if (!WPACKET_allocate_bytes(pkt, 32 * 4, &out))
return 0;
- while (in < end) {
- uint32_t z0 = mod_sub(4, *in++); /* 0..8 */
- uint32_t z1 = mod_sub(4, *in++); /* 0..8 */
+ do {
+ uint32_t z = mod_sub(4, *in++);
- *out++ = z0 | (z1 << 4);
- }
+ *out++ = z | (mod_sub(4, *in++) << 4);
+ } while (in < end);
return 1;
}
for (i = 0; i < (ML_DSA_NUM_POLY_COEFFICIENTS / 8); i++) {
if (!PACKET_get_bytes(pkt, &in, 4))
goto err;
- memcpy(&v, in, 4);
+ in = OSSL_CRYPTO_load_u32_le(&v, in);
/*
* None of the nibbles may be >= 9. So if the MSB of any nibble is set,
if (!WPACKET_allocate_bytes(pkt, 32 * 3, &out))
return 0;
- while (in < end) {
- uint32_t z0 = mod_sub(2, *in++); /* 0..7 */
- uint32_t z1 = mod_sub(2, *in++); /* 0..7 */
- uint32_t z2 = mod_sub(2, *in++); /* 0..7 */
- uint32_t z3 = mod_sub(2, *in++); /* 0..7 */
- uint32_t z4 = mod_sub(2, *in++); /* 0..7 */
- uint32_t z5 = mod_sub(2, *in++); /* 0..7 */
- uint32_t z6 = mod_sub(2, *in++); /* 0..7 */
- uint32_t z7 = mod_sub(2, *in++); /* 0..7 */
-
- *out++ = (uint8_t)z0 | (uint8_t)(z1 << 3) | (uint8_t)(z2 << 6);
- *out++ = (uint8_t)(z2 >> 2) | (uint8_t)(z3 << 1) | (uint8_t)(z4 << 4) | (uint8_t)(z5 << 7);
- *out++ = (uint8_t)(z5 >> 1) | (uint8_t)(z6 << 2) | (uint8_t)(z7 << 5);
- }
+ do {
+ uint32_t z;
+
+ z = mod_sub(2, *in++);
+ z |= mod_sub(2, *in++) << 3;
+ z |= mod_sub(2, *in++) << 6;
+ z |= mod_sub(2, *in++) << 9;
+ z |= mod_sub(2, *in++) << 12;
+ z |= mod_sub(2, *in++) << 15;
+ z |= mod_sub(2, *in++) << 18;
+ z |= mod_sub(2, *in++) << 21;
+
+ out = OSSL_CRYPTO_store_u16_le(out, (uint16_t) z);
+ *out++ = (uint8_t) (z >> 16);
+ } while (in < end);
return 1;
}
static int poly_decode_signed_2(POLY *p, PACKET *pkt)
{
int i, ret = 0;
- uint32_t v = 0, *out = p->coeff;
+ uint32_t u = 0, v = 0, *out = p->coeff;
uint32_t msbs, mask;
const uint8_t *in;
for (i = 0; i < (ML_DSA_NUM_POLY_COEFFICIENTS / 8); i++) {
if (!PACKET_get_bytes(pkt, &in, 3))
goto err;
- memcpy(&v, in, 3);
+ memcpy(&u, in, 3);
+ OSSL_CRYPTO_load_u32_le(&v, (uint8_t *)&u);
+
/*
* Each octal value (3 bits) must be <= 4, so if the MSB is set then the
* bottom 2 bits must not be set.
static const uint32_t range = 1u << 12;
const uint32_t *in = p->coeff, *end = in + ML_DSA_NUM_POLY_COEFFICIENTS;
- while (in < end) {
- uint64_t z0 = mod_sub(range, *in++); /* < 2^13 */
- uint64_t z1 = mod_sub(range, *in++);
- uint64_t z2 = mod_sub(range, *in++);
- uint64_t z3 = mod_sub(range, *in++);
- uint64_t z4 = mod_sub(range, *in++);
- uint64_t z5 = mod_sub(range, *in++);
- uint64_t z6 = mod_sub(range, *in++);
- uint64_t z7 = mod_sub(range, *in++);
- uint64_t a1 = (z0) | (z1 << 13) | (z2 << 26) | (z3 << 39) | (z4 << 52);
- uint64_t a2 = (z4 >> 12) | (z5 << 1) | (z6 << 14) | (z7 << 27);
-
- if (!WPACKET_memcpy(pkt, &a1, 8)
- || !WPACKET_memcpy(pkt, &a2, 5))
+ do {
+ uint8_t *out;
+ uint64_t a1, a2;
+
+ if (!WPACKET_allocate_bytes(pkt, 13, &out))
return 0;
- }
+
+ a1 = mod_sub_64(range, *in++);
+ a1 |= mod_sub_64(range, *in++) << 13;
+ a1 |= mod_sub_64(range, *in++) << 26;
+ a1 |= mod_sub_64(range, *in++) << 39;
+ a1 |= (a2 = mod_sub_64(range, *in++)) << 52;
+ a2 = (a2 >> 12) | (mod_sub_64(range, *in++) << 1);
+ a2 |= mod_sub_64(range, *in++) << 14;
+ a2 |= mod_sub_64(range, *in++) << 27;
+
+ out = OSSL_CRYPTO_store_u64_le(out, a1);
+ out = OSSL_CRYPTO_store_u32_le(out, (uint32_t) a2);
+ *out = (uint8_t) (a2 >> 32);
+ } while (in < end);
return 1;
}
static int poly_decode_signed_two_to_power_12(POLY *p, PACKET *pkt)
{
int i, ret = 0;
- uint64_t a1 = 0, a2 = 0;
uint32_t *out = p->coeff;
const uint8_t *in;
static const uint32_t range = 1u << 12;
static const uint32_t mask_13_bits = (1u << 13) - 1;
for (i = 0; i < (ML_DSA_NUM_POLY_COEFFICIENTS / 8); i++) {
+ uint64_t a1;
+ uint32_t a2, b13;
+
if (!PACKET_get_bytes(pkt, &in, 13))
goto err;
- memcpy(&a1, in, 8);
- memcpy(&a2, in + 8, 5);
+ in = OSSL_CRYPTO_load_u64_le(&a1, in);
+ in = OSSL_CRYPTO_load_u32_le(&a2, in);
+ b13 = (uint32_t) *in;
*out++ = mod_sub(range, a1 & mask_13_bits);
*out++ = mod_sub(range, (a1 >> 13) & mask_13_bits);
*out++ = mod_sub(range, (a1 >> 52) | ((a2 << 12) & mask_13_bits));
*out++ = mod_sub(range, (a2 >> 1) & mask_13_bits);
*out++ = mod_sub(range, (a2 >> 14) & mask_13_bits);
- *out++ = mod_sub(range, (a2 >> 27) & mask_13_bits);
+ *out++ = mod_sub(range, (a2 >> 27) | (b13 << 5));
}
ret = 1;
err:
static const uint32_t range = 1u << 19;
const uint32_t *in = p->coeff, *end = in + ML_DSA_NUM_POLY_COEFFICIENTS;
- while (in < end) {
- uint32_t z0 = mod_sub(range, *in++); /* < 2^20 */
- uint32_t z1 = mod_sub(range, *in++);
- uint32_t z2 = mod_sub(range, *in++);
- uint32_t z3 = mod_sub(range, *in++);
-
- z0 |= (z1 << 20);
- z1 >>= 12;
- z1 |= (z2 << 8) | (z3 << 28);
- z3 >>= 4;
+ do {
+ uint32_t z0, z1, z2;
+ uint8_t *out;
- if (!WPACKET_memcpy(pkt, &z0, sizeof(z0))
- || !WPACKET_memcpy(pkt, &z1, sizeof(z1))
- || !WPACKET_memcpy(pkt, &z3, 2))
+ if (!WPACKET_allocate_bytes(pkt, 10, &out))
return 0;
- }
+
+ z0 = mod_sub(range, *in++);
+ z0 |= (z1 = mod_sub(range, *in++)) << 20;
+ z1 = (z1 >> 12) | (mod_sub(range, *in++) << 8);
+ z1 |= (z2 = mod_sub(range, *in++)) << 28;
+
+ out = OSSL_CRYPTO_store_u32_le(out, z0);
+ out = OSSL_CRYPTO_store_u32_le(out, z1);
+ out = OSSL_CRYPTO_store_u16_le(out, (uint16_t) (z2 >> 4));
+ } while (in < end);
return 1;
}
static int poly_decode_signed_two_to_power_19(POLY *p, PACKET *pkt)
{
int i, ret = 0;
- uint32_t a1, a2, a3 = 0;
uint32_t *out = p->coeff;
const uint8_t *in;
static const uint32_t range = 1u << 19;
static const uint32_t mask_20_bits = (1u << 20) - 1;
for (i = 0; i < (ML_DSA_NUM_POLY_COEFFICIENTS / 4); i++) {
+ uint32_t a1, a2;
+ uint16_t a3;
+
if (!PACKET_get_bytes(pkt, &in, 10))
goto err;
- memcpy(&a1, in, 4);
- memcpy(&a2, in + 4, 4);
- memcpy(&a3, in + 8, 2);
+ in = OSSL_CRYPTO_load_u32_le(&a1, in);
+ in = OSSL_CRYPTO_load_u32_le(&a2, in);
+ in = OSSL_CRYPTO_load_u16_le(&a3, in);
*out++ = mod_sub(range, a1 & mask_20_bits);
*out++ = mod_sub(range, (a1 >> 20) | ((a2 & 0xFF) << 12));
static const uint32_t range = 1u << 17;
const uint32_t *in = p->coeff, *end = in + ML_DSA_NUM_POLY_COEFFICIENTS;
- while (in < end) {
- uint32_t z0 = mod_sub(range, *in++); /* < 2^18 */
- uint32_t z1 = mod_sub(range, *in++);
- uint32_t z2 = mod_sub(range, *in++);
- uint32_t z3 = mod_sub(range, *in++);
-
- z0 |= (z1 << 18);
- z1 >>= 14;
- z1 |= (z2 << 4) | (z3 << 22);
- z3 >>= 10;
+ do {
+ uint8_t *out;
+ uint32_t z0, z1, z2;
- if (!WPACKET_memcpy(pkt, &z0, sizeof(z0))
- || !WPACKET_memcpy(pkt, &z1, sizeof(z1))
- || !WPACKET_memcpy(pkt, &z3, 1))
+ if (!WPACKET_allocate_bytes(pkt, 9, &out))
return 0;
- }
+
+ z0 = mod_sub(range, *in++);
+ z0 |= (z1 = mod_sub(range, *in++)) << 18;
+ z1 = (z1 >> 14) | (mod_sub(range, *in++) << 4);
+ z1 |= (z2 = mod_sub(range, *in++)) << 22;
+
+ out = OSSL_CRYPTO_store_u32_le(out, z0);
+ out = OSSL_CRYPTO_store_u32_le(out, z1);
+ *out = z2 >> 10;
+ } while (in < end);
return 1;
}
*/
static int poly_decode_signed_two_to_power_17(POLY *p, PACKET *pkt)
{
- int ret = 0;
- uint32_t a1, a2, a3 = 0;
uint32_t *out = p->coeff;
const uint32_t *end = out + ML_DSA_NUM_POLY_COEFFICIENTS;
const uint8_t *in;
static const uint32_t range = 1u << 17;
static const uint32_t mask_18_bits = (1u << 18) - 1;
- while (out < end) {
- if (!PACKET_get_bytes(pkt, &in, 10))
- goto err;
- memcpy(&a1, in, 4);
- memcpy(&a2, in + 4, 4);
- memcpy(&a3, in + 8, 1);
+ do {
+ uint32_t a1, a2, a3; /* a1 = bytes 0..3, a2 = bytes 4..7, a3 = byte 8 of the current group */
+
+ if (!PACKET_get_bytes(pkt, &in, 9)) /* 4 coefficients x 18 bits = 72 bits = 9 bytes per group (the old code asked for 10) */
+ return 0; /* could not obtain 9 bytes from pkt; nothing to clean up, so fail directly */
+ in = OSSL_CRYPTO_load_u32_le(&a1, in); /* result is assigned back to in, so the next load continues where this one stopped */
+ in = OSSL_CRYPTO_load_u32_le(&a2, in);
+ a3 = (uint32_t) *in; /* 9th byte; supplies the top 8 bits of the 4th value below */
*out++ = mod_sub(range, a1 & mask_18_bits);
*out++ = mod_sub(range, (a1 >> 18) | ((a2 & 0xF) << 14));
*out++ = mod_sub(range, (a2 >> 4) & mask_18_bits);
*out++ = mod_sub(range, (a2 >> 22) | (a3 << 10));
- }
- ret = 1;
- err:
- return ret;
+ } while (out < end); /* out advances by 4 per pass; stops once all ML_DSA_NUM_POLY_COEFFICIENTS slots are written */
+ return 1; /* every coefficient decoded */
}
/*
uint32_t k = params->k, l = params->l;
uint32_t gamma1 = params->gamma1, gamma2 = params->gamma2;
uint8_t *alloc = NULL, *w1_encoded;
- size_t w1_encoded_len = 128 * k;
+ size_t alloc_len, w1_encoded_len;
size_t num_polys_sig_k = 2 * k;
size_t num_polys_k = 5 * k;
size_t num_polys_l = 3 * l;
size_t num_polys_k_by_l = k * l;
POLY *polys = NULL, *p, *c_ntt;
- size_t alloc_len = w1_encoded_len
- + sizeof(*polys)
- * (1 + num_polys_k + num_polys_l
- + num_polys_k_by_l + num_polys_sig_k);
VECTOR s1_ntt, s2_ntt, t0_ntt, w, w1, cs1, cs2, y;
MATRIX a_ntt;
ML_DSA_SIG sig;
* Allocate a single blob for most of the variable size temporary variables.
* Mostly used for VECTOR POLYNOMIALS (every POLY is 1K).
*/
+ w1_encoded_len = k * (gamma2 == ML_DSA_GAMMA2_Q_MINUS1_DIV88 ? 192 : 128);
+ alloc_len = w1_encoded_len
+ + sizeof(*polys) * (1 + num_polys_k + num_polys_l
+ + num_polys_k_by_l + num_polys_sig_k);
alloc = OPENSSL_malloc(alloc_len);
if (alloc == NULL)
return 0;
vector_high_bits(&w, gamma2, &w1);
ossl_ml_dsa_w1_encode(&w1, gamma2, w1_encoded, w1_encoded_len);
- if (!shake_xof_2(h_ctx, mu, sizeof(mu), w1_encoded, 128 * k,
+ if (!shake_xof_2(h_ctx, mu, sizeof(mu), w1_encoded, w1_encoded_len,
c_tilde, c_tilde_len))
break;
const ML_DSA_PARAMS *params = ctx->params;
uint32_t k = pub->params->k;
uint32_t l = pub->params->l;
- size_t w1_encoded_len = 128 * k;
+ uint32_t gamma2 = params->gamma2;
+ size_t w1_encoded_len;
size_t num_polys_sig = k + l;
size_t num_polys_k = 2 * k;
size_t num_polys_l = 1 * l;
uint32_t z_max;
/* Allocate space for all the POLYNOMIALS used by temporary VECTORS */
+ w1_encoded_len = k * (gamma2 == ML_DSA_GAMMA2_Q_MINUS1_DIV88 ? 192 : 128);
alloc = OPENSSL_malloc(w1_encoded_len
+ sizeof(*polys) * (1 + num_polys_k
+ num_polys_l
/* compute w1_encoded */
w1 = w_approx;
- vector_use_hint(&sig.hint, w_approx, params->gamma2, w1);
- ossl_ml_dsa_w1_encode(w1, params->gamma2, w1_encoded, w1_encoded_len);
+ vector_use_hint(&sig.hint, w_approx, gamma2, w1);
+ ossl_ml_dsa_w1_encode(w1, gamma2, w1_encoded, w1_encoded_len);
if (!shake_xof_3(h_ctx, mu, sizeof(mu), w1_encoded, w1_encoded_len, NULL, 0,
c_tilde, c_tilde_len))