From: slontis Date: Tue, 12 Nov 2024 07:35:10 +0000 (+1100) Subject: Update SLH-DSA code to use PACKET and WPACKET. X-Git-Tag: openssl-3.5.0-alpha1~187 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=148f4d23e1a9becf8984ddc92fa8ebcb3b760bd9;p=thirdparty%2Fopenssl.git Update SLH-DSA code to use PACKET and WPACKET. Reviewed-by: Paul Dale Reviewed-by: Viktor Dukhovni Reviewed-by: Tim Hudson (Merged from https://github.com/openssl/openssl/pull/25882) --- diff --git a/crypto/slh_dsa/slh_dsa.c b/crypto/slh_dsa/slh_dsa.c index 10b0b1f66b0..0fdef850349 100644 --- a/crypto/slh_dsa/slh_dsa.c +++ b/crypto/slh_dsa/slh_dsa.c @@ -12,15 +12,12 @@ #include "slh_dsa_local.h" #include "slh_dsa_key.h" -#define SLH_MAX_M 49 +#define SLH_MAX_M 49 /* See slh_params.c */ +/* The size of md is (21..40 bytes) - since a is in bits round up to nearest byte */ +#define MD_LEN(params) (((params)->k * (params)->a + 7) >> 3) -/* (n + SLH_SIG_FORS_LEN(k, a, n) + SLH_SIG_HT_LEN(n, hm, d)) */ -#define SLH_SIG_RANDOM_LEN(n) (n) -#define SLH_SIG_FORS_LEN(k, a, n) (n) * ((k) * (1 + (a))) -#define SLH_SIG_HT_LEN(h, d, n) (n) * ((h) + (d) * SLH_WOTS_LEN(n)) - -static void get_tree_ids(const uint8_t *digest, const SLH_DSA_PARAMS *params, - uint64_t *tree_id, uint32_t *leaf_id); +static int get_tree_ids(PACKET *pkt, const SLH_DSA_PARAMS *params, + uint64_t *tree_id, uint32_t *leaf_id); /** * @brief SLH-DSA Signature generation @@ -46,23 +43,23 @@ static int slh_sign_internal(SLH_DSA_CTX *ctx, const SLH_DSA_KEY *priv, uint8_t *sig, size_t *sig_len, size_t sig_size, const uint8_t *opt_rand) { + int ret = 0; const SLH_DSA_PARAMS *params = ctx->params; - uint32_t n = params->n; - size_t r_len = n; - size_t sig_fors_len = SLH_SIG_FORS_LEN(params->k, params->a, n); - size_t sig_ht_len = SLH_SIG_HT_LEN(params->h, params->d, n); - size_t sig_len_expected = r_len + sig_fors_len + sig_ht_len; + size_t sig_len_expected = params->sig_len; SLH_HASH_FUNC_DECLARE(ctx, hashf, hctx); SLH_ADRS_FUNC_DECLARE(ctx, adrsf); SLH_ADRS_DECLARE(adrs); + uint8_t m_digest[SLH_MAX_M]; + const uint8_t *md; /* The first md_len bytes of m_digest */ + size_t md_len = MD_LEN(params); /* The size of the digest |md| */ + /* Points to |m_digest| buffer, it is also reused to point to |sig_fors| */ + PACKET r_packet, *rpkt = &r_packet; + uint8_t *r, *sig_fors; /* Pointers into buffer inside |wpkt| */ + WPACKET w_packet, *wpkt = &w_packet; /* Points to output |sig| buffer */ + const uint8_t *pk_seed, *sk_seed; /* pointers to elements within |priv| */ + uint8_t pk_fors[SLH_MAX_N]; uint64_t tree_id; uint32_t leaf_id; - uint8_t pk_fors[SLH_MAX_N]; - uint8_t m_digest[SLH_MAX_M]; - uint8_t *r = sig; - uint8_t *sig_fors = r + r_len; - uint8_t *sig_ht = sig_fors + sig_fors_len; - const uint8_t *md, *pk_seed, *sk_seed; if (sig_len != NULL) *sig_len = sig_len_expected; @@ -76,6 +73,11 @@ static int slh_sign_internal(SLH_DSA_CTX *ctx, const SLH_DSA_KEY *priv, if (priv->has_priv == 0) return 0; + if (!WPACKET_init_static_len(wpkt, sig, sig_len_expected, 0)) + return 0; + if (!PACKET_buf_init(rpkt, m_digest, params->m)) + return 0; + pk_seed = SLH_DSA_PK_SEED(priv); sk_seed = SLH_DSA_SK_SEED(priv); @@ -83,26 +85,38 @@ static int slh_sign_internal(SLH_DSA_CTX *ctx, const SLH_DSA_KEY *priv, opt_rand = pk_seed; adrsf->zero(adrs); - /* calculate Randomness value r, and output to the signature */ - if (!hashf->PRF_MSG(hctx, SLH_DSA_SK_PRF(priv), opt_rand, msg, msg_len, r) + /* calculate Randomness value r, and output to the SLH-DSA signature */ + r = WPACKET_get_curr(wpkt); + if (!hashf->PRF_MSG(hctx, SLH_DSA_SK_PRF(priv), opt_rand, msg, msg_len, wpkt) /* generate a digest of size |params->m| bytes where m is (30..49) */ || !hashf->H_MSG(hctx, r, pk_seed, SLH_DSA_PK_ROOT(priv), msg, msg_len, - m_digest)) - return 0; - /* Grab selected bytes from the digest to select tree and leaf id's */ - get_tree_ids(m_digest, params, &tree_id, &leaf_id); + m_digest, sizeof(m_digest)) + /* Grab the first md_len bytes of m_digest to use in fors_sign() */ + || !PACKET_get_bytes(rpkt, &md, md_len) + /* Grab remaining bytes from m_digest to select tree and leaf id's */ + || !get_tree_ids(rpkt, params, &tree_id, &leaf_id)) + goto err; adrsf->set_tree_address(adrs, tree_id); adrsf->set_type_and_clear(adrs, SLH_ADRS_TYPE_FORS_TREE); adrsf->set_keypair_address(adrs, leaf_id); - /* generate the FORS signature and append it to the signature */ - md = m_digest; - return ossl_slh_fors_sign(ctx, md, sk_seed, pk_seed, adrs, sig_fors, sig_fors_len) - /* Calculate the FORS public key */ - && ossl_slh_fors_pk_from_sig(ctx, sig_fors, md, pk_seed, adrs, pk_fors) + sig_fors = WPACKET_get_curr(wpkt); + /* generate the FORS signature and append it to the SLH-DSA signature */ + ret = ossl_slh_fors_sign(ctx, md, sk_seed, pk_seed, adrs, wpkt) + /* Reuse rpkt to point to the FORS signature that was just generated */ + && PACKET_buf_init(rpkt, sig_fors, WPACKET_get_curr(wpkt) - sig_fors) + /* Calculate the FORS public key using the generated FORS signature */ + && ossl_slh_fors_pk_from_sig(ctx, rpkt, md, pk_seed, adrs, + pk_fors, sizeof(pk_fors)) + /* Generate ht signature and append to the SLH-DSA signature */ && ossl_slh_ht_sign(ctx, pk_fors, sk_seed, pk_seed, tree_id, leaf_id, - sig_ht, sig_ht_len); + wpkt); + ret = 1; + err: + if (!WPACKET_finish(wpkt)) + ret = 0; + return ret; } /** @@ -129,43 +143,57 @@ static int slh_verify_internal(SLH_DSA_CTX *ctx, const SLH_DSA_KEY *pub, SLH_HASH_FUNC_DECLARE(ctx, hashf, hctx); SLH_ADRS_FUNC_DECLARE(ctx, adrsf); SLH_ADRS_DECLARE(adrs); - uint8_t mdigest[SLH_MAX_M]; + const SLH_DSA_PARAMS *params = ctx->params; + uint32_t n = params->n; + const uint8_t *pk_seed, *pk_root; /* Pointers to elements in |pub| */ + PACKET pkt, *sig_rpkt = &pkt; /* Points to the |sig| buffer */ + uint8_t m_digest[SLH_MAX_M]; + const uint8_t *md; /* This is a pointer into the buffer in m_digest_rpkt */ + size_t md_len = MD_LEN(params); /* 21..40 bytes */ + PACKET pkt2, *m_digest_rpkt = &pkt2; /* Points to m_digest buffer */ + const uint8_t *r; /* Pointer to |sig_rpkt| buffer */ uint8_t pk_fors[SLH_MAX_N]; uint64_t tree_id; uint32_t leaf_id; - const SLH_DSA_PARAMS *params = ctx->params; - uint32_t n = params->n; - size_t r_len = SLH_SIG_RANDOM_LEN(n); - size_t sig_fors_len = SLH_SIG_FORS_LEN(params->k, params->a, n); - size_t sig_ht_len = SLH_SIG_HT_LEN(params->h, params->d, n); - const uint8_t *r, *sig_fors, *sig_ht, *md, *pk_seed, *pk_root; - if (sig_len != (r_len + sig_fors_len + sig_ht_len)) - return 0; /* Exit if public key is not set */ if (pub->key_len == 0) return 0; - adrsf->zero(adrs); + /* Exit if signature is invalid size */ + if (sig_len != ctx->params->sig_len + || !PACKET_buf_init(sig_rpkt, sig, sig_len)) + return 0; + if (!PACKET_get_bytes(sig_rpkt, &r, n)) + return 0; - r = sig; - sig_fors = r + r_len; - sig_ht = sig_fors + sig_fors_len; + adrsf->zero(adrs); pk_seed = SLH_DSA_PK_SEED(pub); pk_root = SLH_DSA_PK_ROOT(pub); - if (!hashf->H_MSG(hctx, r, pk_seed, pk_root, msg, msg_len, mdigest)) + if (!hashf->H_MSG(hctx, r, pk_seed, pk_root, msg, msg_len, + m_digest, sizeof(m_digest))) + return 0; + + /* + * Get md (the first md_len bytes of m_digest to use in + * ossl_slh_fors_pk_from_sig(), and then retrieve the tree id and leaf id + * from the remaining bytes in m_digest. + */ + if (!PACKET_buf_init(m_digest_rpkt, m_digest, sizeof(m_digest)) + || !PACKET_get_bytes(m_digest_rpkt, &md, md_len) + || !get_tree_ids(m_digest_rpkt, params, &tree_id, &leaf_id)) return 0; - md = mdigest; - get_tree_ids(mdigest, params, &tree_id, &leaf_id); adrsf->set_tree_address(adrs, tree_id); adrsf->set_type_and_clear(adrs, SLH_ADRS_TYPE_FORS_TREE); adrsf->set_keypair_address(adrs, leaf_id); - return ossl_slh_fors_pk_from_sig(ctx, sig_fors, md, pk_seed, adrs, pk_fors) - && ossl_slh_ht_verify(ctx, pk_fors, sig_ht, pk_seed, - tree_id, leaf_id, pk_root); + return ossl_slh_fors_pk_from_sig(ctx, sig_rpkt, md, pk_seed, adrs, + pk_fors, sizeof(pk_fors)) + && ossl_slh_ht_verify(ctx, pk_fors, sig_rpkt, pk_seed, + tree_id, leaf_id, pk_root) + && PACKET_remaining(sig_rpkt) == 0; } /** @@ -288,21 +316,20 @@ static uint64_t bytes_to_u64_be(const uint8_t *in, size_t in_len) * Converts digested bytes into a tree index, and leaf index within the tree. * The sizes are determined by the |params| parameter set. */ -static void get_tree_ids(const uint8_t *digest, const SLH_DSA_PARAMS *params, - uint64_t *tree_id, uint32_t *leaf_id) +static int get_tree_ids(PACKET *rpkt, const SLH_DSA_PARAMS *params, + uint64_t *tree_id, uint32_t *leaf_id) { const uint8_t *tree_id_bytes, *leaf_id_bytes; - uint32_t md_len, tree_id_len, leaf_id_len; + uint32_t tree_id_len, leaf_id_len; uint64_t tree_id_mask, leaf_id_mask; - md_len = ((params->k * params->a + 7) >> 3); /* 21..40 bytes */ tree_id_len = ((params->h - params->hm + 7) >> 3); /* 7 or 8 bytes */ leaf_id_len = ((params->hm + 7) >> 3); /* 1 or 2 bytes */ - tree_id_bytes = digest + md_len; - leaf_id_bytes = tree_id_bytes + tree_id_len; + if (!PACKET_get_bytes(rpkt, &tree_id_bytes, tree_id_len) + || !PACKET_get_bytes(rpkt, &leaf_id_bytes, leaf_id_len)) + return 0; - assert((md_len + tree_id_len + leaf_id_len) == params->m); /* * In order to calculate A mod (2^X) where X is in the range of (54..64) * This is equivalent to A & (2^x - 1) which is just a sequence of X ones @@ -315,4 +342,5 @@ static void get_tree_ids(const uint8_t *digest, const SLH_DSA_PARAMS *params, leaf_id_mask = (1 << params->hm) - 1; /* max value is 0x1FF when hm = 9 */ *tree_id = bytes_to_u64_be(tree_id_bytes, tree_id_len) & tree_id_mask; *leaf_id = (uint32_t)(bytes_to_u64_be(leaf_id_bytes, leaf_id_len) & leaf_id_mask); + return 1; } diff --git a/crypto/slh_dsa/slh_dsa_key.c b/crypto/slh_dsa/slh_dsa_key.c index e189c8e55e4..cf4f5c8787b 100644 --- a/crypto/slh_dsa/slh_dsa_key.c +++ b/crypto/slh_dsa/slh_dsa_key.c @@ -204,13 +204,12 @@ static int slh_dsa_compute_pk_root(SLH_DSA_CTX *ctx, SLH_DSA_KEY *out) SLH_ADRS_DECLARE(adrs); const SLH_DSA_PARAMS *params = out->params; - assert(params != NULL); - adrsf->zero(adrs); adrsf->set_layer_address(adrs, params->d - 1); /* Generate the ROOT public key */ return ossl_slh_xmss_node(ctx, SLH_DSA_SK_SEED(out), 0, params->hm, - SLH_DSA_PK_SEED(out), adrs, SLH_DSA_PK_ROOT(out)); + SLH_DSA_PK_SEED(out), adrs, + SLH_DSA_PK_ROOT(out), params->n); } /** diff --git a/crypto/slh_dsa/slh_dsa_local.h b/crypto/slh_dsa/slh_dsa_local.h index 59c5fb605d5..8b4a40941d6 100644 --- a/crypto/slh_dsa/slh_dsa_local.h +++ b/crypto/slh_dsa/slh_dsa_local.h @@ -46,42 +46,43 @@ struct slh_dsa_ctx_st { SLH_HASH_CTX hash_ctx; }; -__owur int ossl_slh_wots_pk_gen(SLH_DSA_CTX *ctx, - const uint8_t *sk_seed, const uint8_t *pk_seed, - SLH_ADRS adrs, uint8_t *pk_out); +__owur int ossl_slh_wots_pk_gen(SLH_DSA_CTX *ctx, const uint8_t *sk_seed, + const uint8_t *pk_seed, SLH_ADRS adrs, + uint8_t *pk_out, size_t pk_out_len); __owur int ossl_slh_wots_sign(SLH_DSA_CTX *ctx, const uint8_t *msg, const uint8_t *sk_seed, const uint8_t *pk_seed, - SLH_ADRS adrs, uint8_t *sig, size_t sig_len); + SLH_ADRS adrs, WPACKET *sig_wpkt); __owur int ossl_slh_wots_pk_from_sig(SLH_DSA_CTX *ctx, - const uint8_t *sig, const uint8_t *msg, - const uint8_t *pk_seed, uint8_t *adrs, - uint8_t *pk_out); + PACKET *sig_rpkt, const uint8_t *msg, + const uint8_t *pk_seed, SLH_ADRS adrs, + uint8_t *pk_out, size_t pk_out_len); __owur int ossl_slh_xmss_node(SLH_DSA_CTX *ctx, const uint8_t *sk_seed, uint32_t node_id, uint32_t height, const uint8_t *pk_seed, SLH_ADRS adrs, - uint8_t *pk_out); + uint8_t *pk_out, size_t pk_out_len); __owur int ossl_slh_xmss_sign(SLH_DSA_CTX *ctx, const uint8_t *msg, const uint8_t *sk_seed, uint32_t node_id, const uint8_t *pk_seed, SLH_ADRS adrs, - uint8_t *sig, size_t sig_len); + WPACKET *sig_wpkt); __owur int ossl_slh_xmss_pk_from_sig(SLH_DSA_CTX *ctx, uint32_t node_id, - const uint8_t *sig, const uint8_t *msg, + PACKET *sig_rpkt, const uint8_t *msg, const uint8_t *pk_seed, SLH_ADRS adrs, - uint8_t *pk_out); + uint8_t *pk_out, size_t pk_out_len); __owur int ossl_slh_ht_sign(SLH_DSA_CTX *ctx, const uint8_t *msg, const uint8_t *sk_seed, const uint8_t *pk_seed, uint64_t tree_id, uint32_t leaf_id, - uint8_t *sig_out, size_t sig_out_len); + WPACKET *sig_wpkt); __owur int ossl_slh_ht_verify(SLH_DSA_CTX *ctx, const uint8_t *msg, - const uint8_t *sig, const uint8_t *pk_seed, + PACKET *sig_rpkt, const uint8_t *pk_seed, uint64_t tree_id, uint32_t leaf_id, const uint8_t *pk_root); __owur int ossl_slh_fors_sign(SLH_DSA_CTX *ctx, const uint8_t *md, const uint8_t *sk_seed, const uint8_t *pk_seed, - SLH_ADRS adrs, uint8_t *sig, size_t sig_len); -__owur int ossl_slh_fors_pk_from_sig(SLH_DSA_CTX *ctx, const uint8_t *sig, + SLH_ADRS adrs, WPACKET *sig_wpkt); +__owur int ossl_slh_fors_pk_from_sig(SLH_DSA_CTX *ctx, PACKET *sig_rpkt, const uint8_t *md, const uint8_t *pk_seed, - SLH_ADRS adrs, uint8_t *pk_out); + SLH_ADRS adrs, + uint8_t *pk_out, size_t pk_out_len); diff --git a/crypto/slh_dsa/slh_fors.c b/crypto/slh_dsa/slh_fors.c index 499a5ff0676..e3f8655b1cb 100644 --- a/crypto/slh_dsa/slh_fors.c +++ b/crypto/slh_dsa/slh_fors.c @@ -35,11 +35,12 @@ static void slh_base_2b(const uint8_t *in, uint32_t b, uint32_t *out, size_t out * @param id The index of the FORS secret value within the sets of FORS trees. * (which must be < 2^(hm - height) * @param pk_out The generated FORS secret value of size |n| + * @param pk_out_len The maximum size of |pk_out| * @returns 1 on success, or 0 on error. */ static int slh_fors_sk_gen(SLH_DSA_CTX *ctx, const uint8_t *sk_seed, const uint8_t *pk_seed, SLH_ADRS adrs, uint32_t id, - uint8_t *sk_out) + uint8_t *pk_out, size_t pk_out_len) { SLH_ADRS_DECLARE(sk_adrs); SLH_ADRS_FUNC_DECLARE(ctx, adrsf); @@ -48,7 +49,8 @@ static int slh_fors_sk_gen(SLH_DSA_CTX *ctx, const uint8_t *sk_seed, adrsf->set_type_and_clear(sk_adrs, SLH_ADRS_TYPE_FORS_PRF); adrsf->copy_keypair_address(sk_adrs, adrs); adrsf->set_tree_index(sk_adrs, id); - return ctx->hash_func->PRF(&ctx->hash_ctx, pk_seed, sk_seed, sk_adrs, sk_out); + return ctx->hash_func->PRF(&ctx->hash_ctx, pk_seed, sk_seed, sk_adrs, + pk_out, pk_out_len); } /** @@ -69,11 +71,12 @@ static int slh_fors_sk_gen(SLH_DSA_CTX *ctx, const uint8_t *sk_seed, * @param node_id The target node index * @param height The target node height * @param node The returned hash for a node of size|n| + * @param node_len The maximum size of |node| * @returns 1 on success, or 0 on error. */ static int slh_fors_node(SLH_DSA_CTX *ctx, const uint8_t *sk_seed, const uint8_t *pk_seed, SLH_ADRS adrs, uint32_t node_id, - uint32_t height, uint8_t *node) + uint32_t height, uint8_t *node, size_t node_len) { int ret = 0; SLH_ADRS_FUNC_DECLARE(ctx, adrsf); @@ -81,22 +84,26 @@ static int slh_fors_node(SLH_DSA_CTX *ctx, const uint8_t *sk_seed, uint32_t n = ctx->params->n; if (height == 0) { - if (!slh_fors_sk_gen(ctx, sk_seed, pk_seed, adrs, node_id, sk)) + /* Gets here for leaf nodes */ + if (!slh_fors_sk_gen(ctx, sk_seed, pk_seed, adrs, node_id, + sk, sizeof(sk))) return 0; adrsf->set_tree_height(adrs, 0); adrsf->set_tree_index(adrs, node_id); - ret = ctx->hash_func->F(&ctx->hash_ctx, pk_seed, adrs, sk, n, node); + ret = ctx->hash_func->F(&ctx->hash_ctx, pk_seed, adrs, sk, n, + node, node_len); OPENSSL_cleanse(sk, n); return ret; } else { if (!slh_fors_node(ctx, sk_seed, pk_seed, adrs, 2 * node_id, height - 1, - lnode) + lnode, sizeof(rnode)) || !slh_fors_node(ctx, sk_seed, pk_seed, adrs, 2 * node_id + 1, - height - 1, rnode)) + height - 1, rnode, sizeof(rnode))) return 0; adrsf->set_tree_height(adrs, height); adrsf->set_tree_index(adrs, node_id); - if (!ctx->hash_func->H(&ctx->hash_ctx, pk_seed, adrs, lnode, rnode, node)) + if (!ctx->hash_func->H(&ctx->hash_ctx, pk_seed, adrs, lnode, rnode, + node, node_len)) return 0; } return 1; @@ -106,6 +113,10 @@ static int slh_fors_node(SLH_DSA_CTX *ctx, const uint8_t *sk_seed, * @brief Generate an FORS signature * See FIPS 205 Section 8.3 Algorithm 16 * + * A FORS signature has a size of (k * (1 + a) * n) bytes + * There are k trees, each of which have a private key value of size |n| followed + * by an authentication path of size |a| (where each path is size |n|) + * * @param ctx Contains SLH_DSA algorithm functions and constants. * @param md A message digest of size |(k * a + 7) / 8| bytes to sign * @param sk_seed A private key seed of size |n| @@ -114,24 +125,23 @@ static int slh_fors_node(SLH_DSA_CTX *ctx, const uint8_t *sk_seed, * tree address set to the XMSS tree that signs the FORS key, * the type set to FORS_TREE, and the keypair address set to the * index of the WOTS+ key that signs the FORS key. - * @param sig_out The generated XMSS signature which consists of a WOTS+ - * signature and authentication path + * @param sig_wpkt A WPACKET object to write the generated XMSS signature to * @param sig_len The size of |sig| which is (2 * n + 3) * n + tree_height * n. * @returns 1 on success, or 0 on error. */ int ossl_slh_fors_sign(SLH_DSA_CTX *ctx, const uint8_t *md, const uint8_t *sk_seed, const uint8_t *pk_seed, - SLH_ADRS adrs, uint8_t *sig, size_t sig_len) + SLH_ADRS adrs, WPACKET *sig_wpkt) { - uint32_t i, j, s; + uint32_t tree_id, layer, s, tree_offset; uint32_t ids[SLH_MAX_K]; const SLH_DSA_PARAMS *params = ctx->params; uint32_t n = params->n; - uint32_t k = params->k; + uint32_t k = params->k; /* number of trees */ uint32_t a = params->a; - uint32_t t = (1 << a); - uint32_t t_times_i = 0; - uint8_t *psig = sig; + uint32_t two_power_a = (1 << a); /* this is t in FIPS 205 */ + uint32_t tree_id_times_two_power_a = 0; + uint8_t out[SLH_MAX_N]; /* * Split md into k a-bit values e.g with k = 14, a = 12 @@ -139,25 +149,42 @@ int ossl_slh_fors_sign(SLH_DSA_CTX *ctx, const uint8_t *md, */ slh_base_2b(md, a, ids, k); - for (i = 0; i < k; ++i) { - uint32_t id = ids[i]; /* |id| = |a| bits */ + for (tree_id = 0; tree_id < k; ++tree_id) { + /* Get the tree[i] leaf id */ + uint32_t node_id = ids[tree_id]; /* |id| = |a| bits */ + + /* + * Give each of the k trees a unique range at each level. + * e.g. If we have 4096 leaf nodes (2^a = 2^12) for each tree + * tree i will use indexes from 4096 * i + (0..4095) for its bottom level. + * For the next level up from the bottom there would be 2048 nodes + * (so tree i uses indexes 2048 * i + (0...2047) for this level) + */ + tree_offset = tree_id_times_two_power_a; if (!slh_fors_sk_gen(ctx, sk_seed, pk_seed, adrs, - id + t_times_i, psig)) + node_id + tree_id_times_two_power_a, out, sizeof(out)) + || !WPACKET_memcpy(sig_wpkt, out, n)) return 0; - psig += n; - for (j = 0; j < a; ++j) { - s = id ^ 1; - if (!slh_fors_node(ctx, sk_seed, pk_seed, adrs, s + i * (1 << (a - j)), - j, psig)) + /* + * Traverse from the bottom of the tree (layer = 0) + * up to the root (layer = a - 1). + * TODO - This is a really inefficient way of doing this, since at + * layer a - 1 it calculates most of the hashes of the entire tree as + * well as all the leaf nodes. So it is calculating nodes multiple times. + */ + for (layer = 0; layer < a; ++layer) { + s = node_id ^ 1; /* XOR gets the index of the other child in a binary tree */ + if (!slh_fors_node(ctx, sk_seed, pk_seed, adrs, + s + tree_offset, layer, out, sizeof(out))) return 0; - id >>= 1; - psig += n; + node_id >>= 1;/* Get the parent node id */ + tree_offset >>= 1; /* Each layer up has half as many nodes */ + WPACKET_memcpy(sig_wpkt, out, n); } - t_times_i += t; + tree_id_times_two_power_a += two_power_a; } - assert((size_t)(psig - sig) == sig_len); return 1; } @@ -165,8 +192,10 @@ int ossl_slh_fors_sign(SLH_DSA_CTX *ctx, const uint8_t *md, * @brief Compute a candidate FORS public key from a message and signature. * See FIPS 205 Section 8.4 Algorithm 17. * + * A FORS signature has a size of (k * (a + 1) * n) bytes + * * @param ctx Contains SLH_DSA algorithm functions and constants. - * @param sig A FORS signature of size (k * (a + 1) * n) bytes + * @param fors_sig_rpkt A PACKET object to read a FORS signature from * @param md A message digest of size (k * a / 8) bytes * @param pk_seed A public key seed of size |n| * @param adrs The ADRS object must have a layer address of zero, and the @@ -174,12 +203,14 @@ int ossl_slh_fors_sign(SLH_DSA_CTX *ctx, const uint8_t *md, * the type set to FORS_TREE, and the keypair address set to the * index of the WOTS+ key that signs the FORS key. * @param pk_out The returned candidate FORS public key of size |n| + * @param pk_out_len The maximum size of |pk_out| * @returns 1 on success, or 0 on error. */ -int ossl_slh_fors_pk_from_sig(SLH_DSA_CTX *ctx, const uint8_t *sig, +int ossl_slh_fors_pk_from_sig(SLH_DSA_CTX *ctx, PACKET *fors_sig_rpkt, const uint8_t *md, const uint8_t *pk_seed, - SLH_ADRS adrs, uint8_t *pk_out) + SLH_ADRS adrs, uint8_t *pk_out, size_t pk_out_len) { + int ret = 0; SLH_ADRS_DECLARE(pk_adrs); SLH_ADRS_FUNC_DECLARE(ctx, adrsf); SLH_ADRS_FN_DECLARE(adrsf, set_tree_index); @@ -189,12 +220,19 @@ int ossl_slh_fors_pk_from_sig(SLH_DSA_CTX *ctx, const uint8_t *sig, SLH_HASH_FN_DECLARE(hashf, H); uint32_t i, j, aoff = 0; uint32_t ids[SLH_MAX_K]; - uint8_t roots[SLH_MAX_ROOTS], *node = roots; const SLH_DSA_PARAMS *params = ctx->params; uint32_t a = params->a; uint32_t k = params->k; uint32_t n = params->n; uint32_t two_power_a = (1 << a); + const uint8_t *sk, *authj; /* Pointers to |sig| buffer inside fors_sig_rpkt */ + uint8_t roots[SLH_MAX_ROOTS]; + size_t roots_len = 0; /* The size of |roots| */ + uint8_t *node0, *node1; /* Pointers into roots[] */ + WPACKET root_pkt, *wroot_pkt = &root_pkt; /* Points to |roots| buffer */ + + if (!WPACKET_init_static_len(wroot_pkt, roots, sizeof(roots), 0)) + return 0; /* Split md into k a-bit values e.g ids[0..k-1] = 12 bits each of md */ slh_base_2b(md, a, ids, k); @@ -206,36 +244,49 @@ int ossl_slh_fors_pk_from_sig(SLH_DSA_CTX *ctx, const uint8_t *sig, set_tree_height(adrs, 0); set_tree_index(adrs, node_id); - if (!F(hctx, pk_seed, adrs, sig, n, node)) - return 0; - sig += n; + /* Regenerate the public key of the leaf */ + if (!PACKET_get_bytes(fors_sig_rpkt, &sk, n) + || !WPACKET_allocate_bytes(wroot_pkt, n, &node0) + || !F(hctx, pk_seed, adrs, sk, n, node0, n)) + goto err; + + /* This omits the copying of the nodes that the FIPS 205 code does */ + node1 = node0; for (j = 0; j < a; ++j) { + /* Get this layers other child public key */ + if (!PACKET_get_bytes(fors_sig_rpkt, &authj, n)) + goto err; + /* Hash the children together to get the parent nodes public key */ set_tree_height(adrs, j + 1); if ((id & 1) == 0) { node_id >>= 1; set_tree_index(adrs, node_id); - if (!H(hctx, pk_seed, adrs, node, sig, node)) - return 0; + if (!H(hctx, pk_seed, adrs, node0, authj, node1, n)) + goto err; } else { node_id = (node_id - 1) >> 1; set_tree_index(adrs, node_id); - if (!H(hctx, pk_seed, adrs, sig, node, node)) - return 0; + if (!H(hctx, pk_seed, adrs, authj, node0, node1, n)) + goto err; } id >>= 1; - sig += n; } aoff += two_power_a; - node += n; } - assert((size_t)(node - roots) <= sizeof(roots)); + if (!WPACKET_get_total_written(wroot_pkt, &roots_len)) + goto err; /* The public key is the hash of all the roots of the k trees */ adrsf->copy(pk_adrs, adrs); adrsf->set_type_and_clear(pk_adrs, SLH_ADRS_TYPE_FORS_ROOTS); adrsf->copy_keypair_address(pk_adrs, adrs); - return hashf->T(hctx, pk_seed, pk_adrs, roots, node - roots, pk_out); + ret = hashf->T(hctx, pk_seed, pk_adrs, roots, roots_len, + pk_out, pk_out_len); + err: + if (!WPACKET_finish(wroot_pkt)) + ret = 0; + return ret; } /** diff --git a/crypto/slh_dsa/slh_hash.c b/crypto/slh_dsa/slh_hash.c index 9eafb42002d..7f240276d5f 100644 --- a/crypto/slh_dsa/slh_hash.c +++ b/crypto/slh_dsa/slh_hash.c @@ -160,21 +160,25 @@ static ossl_inline int xof_digest_4(EVP_MD_CTX *ctx, static int slh_hmsg_shake(SLH_HASH_CTX *hctx, const uint8_t *r, const uint8_t *pk_seed, const uint8_t *pk_root, const uint8_t *msg, size_t msg_len, - uint8_t *out) + uint8_t *out, size_t out_len) { size_t m = hctx->m; size_t n = hctx->n; + assert(m <= out_len); + return xof_digest_4(hctx->md_ctx, r, n, pk_seed, n, pk_root, n, msg, msg_len, out, m); } static int slh_prf_shake(SLH_HASH_CTX *hctx, const uint8_t *pk_seed, const uint8_t *sk_seed, - const SLH_ADRS adrs, uint8_t *out) + const SLH_ADRS adrs, uint8_t *out, size_t out_len) { size_t n = hctx->n; + assert(n <= out_len); + return xof_digest_3(hctx->md_ctx, pk_seed, n, adrs, SLH_ADRS_SIZE, sk_seed, n, out, n); } @@ -182,40 +186,50 @@ slh_prf_shake(SLH_HASH_CTX *hctx, const uint8_t *pk_seed, const uint8_t *sk_seed static int slh_prf_msg_shake(SLH_HASH_CTX *hctx, const uint8_t *sk_prf, const uint8_t *opt_rand, const uint8_t *msg, size_t msg_len, - uint8_t *out) + WPACKET *pkt) { + unsigned char out[SLH_MAX_N]; size_t n = hctx->n; + assert(n <= sizeof(out)); + return xof_digest_3(hctx->md_ctx, sk_prf, n, opt_rand, n, - msg, msg_len, out, n); + msg, msg_len, out, n) + && WPACKET_memcpy(pkt, out, n); } static int slh_f_shake(SLH_HASH_CTX *hctx, const uint8_t *pk_seed, const SLH_ADRS adrs, - const uint8_t *m1, size_t m1_len, uint8_t *out) + const uint8_t *m1, size_t m1_len, uint8_t *out, size_t out_len) { size_t n = hctx->n; + assert(n <= out_len); + return xof_digest_3(hctx->md_ctx, pk_seed, n, adrs, SLH_ADRS_SIZE, m1, m1_len, out, n); } static int slh_h_shake(SLH_HASH_CTX *hctx, const uint8_t *pk_seed, const SLH_ADRS adrs, - const uint8_t *m1, const uint8_t *m2, uint8_t *out) + const uint8_t *m1, const uint8_t *m2, uint8_t *out, size_t out_len) { size_t n = hctx->n; + assert(n <= out_len); + return xof_digest_4(hctx->md_ctx, pk_seed, n, adrs, SLH_ADRS_SIZE, m1, n, m2, n, out, n); } static int slh_t_shake(SLH_HASH_CTX *hctx, const uint8_t *pk_seed, const SLH_ADRS adrs, - const uint8_t *ml, size_t ml_len, uint8_t *out) + const uint8_t *ml, size_t ml_len, uint8_t *out, size_t out_len) { size_t n = hctx->n; + assert(n <= out_len); + return xof_digest_3(hctx->md_ctx, pk_seed, n, adrs, SLH_ADRS_SIZE, ml, ml_len, out, n); } @@ -239,13 +253,15 @@ digest_4(EVP_MD_CTX *ctx, static int slh_hmsg_sha2(SLH_HASH_CTX *hctx, const uint8_t *r, const uint8_t *pk_seed, const uint8_t *pk_root, const uint8_t *msg, size_t msg_len, - uint8_t *out) + uint8_t *out, size_t out_len) { + size_t m = hctx->m; size_t n = hctx->n; uint8_t seed[2 * SLH_MAX_N + MAX_DIGEST_SIZE]; int sz = EVP_MD_get_size(hctx->md); size_t seed_len = (size_t)sz + 2 * n; + assert(m <= out_len); assert(sz > 0); assert(seed_len <= sizeof(seed)); @@ -253,13 +269,13 @@ slh_hmsg_sha2(SLH_HASH_CTX *hctx, const uint8_t *r, const uint8_t *pk_seed, memcpy(seed + n, pk_seed, n); return digest_4(hctx->md_big_ctx, r, n, pk_seed, n, pk_root, n, msg, msg_len, seed + 2 * n) - && (PKCS1_MGF1(out, hctx->m, seed, seed_len, hctx->md) == 0); + && (PKCS1_MGF1(out, m, seed, seed_len, hctx->md) == 0); } static int slh_prf_msg_sha2(SLH_HASH_CTX *hctx, const uint8_t *sk_prf, const uint8_t *opt_rand, - const uint8_t *msg, size_t msg_len, uint8_t *out) + const uint8_t *msg, size_t msg_len, WPACKET *pkt) { int ret; EVP_MAC_CTX *mctx = hctx->hmac_ctx; @@ -288,19 +304,20 @@ slh_prf_msg_sha2(SLH_HASH_CTX *hctx, ret = EVP_MAC_init(mctx, sk_prf, n, p) == 1 && EVP_MAC_update(mctx, opt_rand, n) == 1 && EVP_MAC_update(mctx, msg, msg_len) == 1 - && EVP_MAC_final(mctx, mac, NULL, sizeof(mac)) == 1; - memcpy(out, mac, n); /* Truncate output to n bytes */ + && EVP_MAC_final(mctx, mac, NULL, sizeof(mac)) == 1 + && WPACKET_memcpy(pkt, mac, n); /* Truncate output to n bytes */ return ret; } static ossl_inline int do_hash(EVP_MD_CTX *ctx, size_t n, const uint8_t *pk_seed, const SLH_ADRS adrs, - const uint8_t *m, size_t m_len, size_t b, uint8_t *out) + const uint8_t *m, size_t m_len, size_t b, uint8_t *out, size_t out_len) { int ret; uint8_t zeros[128] = { 0 }; uint8_t digest[MAX_DIGEST_SIZE]; + assert(n <= out_len); assert(b - n < sizeof(zeros)); ret = digest_4(ctx, pk_seed, n, zeros, b - n, adrs, SLH_ADRSC_SIZE, @@ -312,25 +329,26 @@ do_hash(EVP_MD_CTX *ctx, size_t n, const uint8_t *pk_seed, const SLH_ADRS adrs, static int slh_prf_sha2(SLH_HASH_CTX *hctx, const uint8_t *pk_seed, - const uint8_t *sk_seed, const SLH_ADRS adrs, uint8_t *out) + const uint8_t *sk_seed, const SLH_ADRS adrs, + uint8_t *out, size_t out_len) { size_t n = hctx->n; return do_hash(hctx->md_ctx, n, pk_seed, adrs, sk_seed, n, - SHA2_NUM_ZEROS_BOUND1, out); + SHA2_NUM_ZEROS_BOUND1, out, out_len); } static int slh_f_sha2(SLH_HASH_CTX *hctx, const uint8_t *pk_seed, const SLH_ADRS adrs, - const uint8_t *m1, size_t m1_len, uint8_t *out) + const uint8_t *m1, size_t m1_len, uint8_t *out, size_t out_len) { return do_hash(hctx->md_ctx, hctx->n, pk_seed, adrs, m1, m1_len, - SHA2_NUM_ZEROS_BOUND1, out); + SHA2_NUM_ZEROS_BOUND1, out, out_len); } static int slh_h_sha2(SLH_HASH_CTX *hctx, const uint8_t *pk_seed, const SLH_ADRS adrs, - const uint8_t *m1, const uint8_t *m2, uint8_t *out) + const uint8_t *m1, const uint8_t *m2, uint8_t *out, size_t out_len) { uint8_t m[SLH_MAX_N * 2]; size_t n = hctx->n; @@ -338,15 +356,15 @@ slh_h_sha2(SLH_HASH_CTX *hctx, const uint8_t *pk_seed, const SLH_ADRS adrs, memcpy(m, m1, n); memcpy(m + n, m2, n); return do_hash(hctx->md_big_ctx, n, pk_seed, adrs, m, 2 * n, - hctx->sha2_h_and_t_bound, out); + hctx->sha2_h_and_t_bound, out, out_len); } static int slh_t_sha2(SLH_HASH_CTX *hctx, const uint8_t *pk_seed, const SLH_ADRS adrs, - const uint8_t *ml, size_t ml_len, uint8_t *out) + const uint8_t *ml, size_t ml_len, uint8_t *out, size_t out_len) { return do_hash(hctx->md_big_ctx, hctx->n, pk_seed, adrs, ml, ml_len, - hctx->sha2_h_and_t_bound, out); + hctx->sha2_h_and_t_bound, out, out_len); } const SLH_HASH_FUNC *ossl_slh_get_hash_fn(int is_shake) diff --git a/crypto/slh_dsa/slh_hash.h b/crypto/slh_dsa/slh_hash.h index b6e3de77dba..a48b7f2ff4d 100644 --- a/crypto/slh_dsa/slh_hash.h +++ b/crypto/slh_dsa/slh_hash.h @@ -13,6 +13,7 @@ # include # include "slh_adrs.h" +# include "internal/packet.h" # define SLH_HASH_FUNC_DECLARE(ctx, hashf, hashctx) \ const SLH_HASH_FUNC *hashf = ctx->hash_func; \ @@ -40,22 +41,26 @@ typedef struct slh_hash_ctx_st { */ typedef int (OSSL_SLH_HASHFUNC_H_MSG)(SLH_HASH_CTX *ctx, const uint8_t *r, const uint8_t *pk_seed, const uint8_t *pk_root, - const uint8_t *msg, size_t msg_len, uint8_t *out); + const uint8_t *msg, size_t msg_len, uint8_t *out, size_t out_len); typedef int (OSSL_SLH_HASHFUNC_PRF)(SLH_HASH_CTX *ctx, const uint8_t *pk_seed, - const uint8_t *sk_seed, const SLH_ADRS adrs, uint8_t *out); + const uint8_t *sk_seed, const SLH_ADRS adrs, + uint8_t *out, size_t out_len); typedef int (OSSL_SLH_HASHFUNC_PRF_MSG)(SLH_HASH_CTX *ctx, const uint8_t *sk_prf, - const uint8_t *opt_rand, const uint8_t *msg, size_t msg_len, uint8_t *out); + const uint8_t *opt_rand, const uint8_t *msg, size_t msg_len, WPACKET *pkt); typedef int (OSSL_SLH_HASHFUNC_F)(SLH_HASH_CTX *ctx, const uint8_t *pk_seed, - const SLH_ADRS adrs, const uint8_t *m1, size_t m1_len, uint8_t *out); + const SLH_ADRS adrs, const uint8_t *m1, size_t m1_len, + uint8_t *out, size_t out_len); typedef int (OSSL_SLH_HASHFUNC_H)(SLH_HASH_CTX *ctx, const uint8_t *pk_seed, - const SLH_ADRS adrs, const uint8_t *m1, const uint8_t *m2, uint8_t *out); + const SLH_ADRS adrs, const uint8_t *m1, const uint8_t *m2, + uint8_t *out, size_t out_len); typedef int (OSSL_SLH_HASHFUNC_T)(SLH_HASH_CTX *ctx, const uint8_t *pk_seed, - const SLH_ADRS adrs, const uint8_t *m1, size_t m1_len, uint8_t *out); + const SLH_ADRS adrs, const uint8_t *m1, size_t m1_len, + uint8_t *out, size_t out_len); typedef struct slh_hash_func_st { OSSL_SLH_HASHFUNC_H_MSG *H_MSG; diff --git a/crypto/slh_dsa/slh_hypertree.c b/crypto/slh_dsa/slh_hypertree.c index fa16d45b77f..e916fa8eff1 100644 --- a/crypto/slh_dsa/slh_hypertree.c +++ b/crypto/slh_dsa/slh_hypertree.c @@ -11,27 +11,25 @@ #include #include "slh_dsa_local.h" -#define SLH_XMSS_SIG_LEN(n, hm) ((SLH_WOTS_LEN(n) + (hm)) * (n)) - /** * @brief Generate a Hypertree Signature * See FIPS 205 Section 7.1 Algorithm 12 * + * This writes |d| XMSS signatures i.e. ((|h| + |d| * |len|) * |n|) + * * @param ctx Contains SLH_DSA algorithm functions and constants. * @param msg A message of size |n|. * @param sk_seed The private key seed of size |n| * @param pk_seed The public key seed of size |n| * @param tree_id Index of the XMSS tree that will sign the message * @param leaf_id Index of the WOTS+ key within the XMSS tree that will signed the message - * @param sig The returned Hypertree Signature (which is |d| XMSS signatures) - * @param sig_len The size of |sig| which is (|h| + |d| * |len|) * |n|) + * @param sig_wpkt A WPACKET object to write the Hypertree Signature to. * @returns 1 on success, or 0 on error. */ int ossl_slh_ht_sign(SLH_DSA_CTX *ctx, const uint8_t *msg, const uint8_t *sk_seed, const uint8_t *pk_seed, - uint64_t tree_id, uint32_t leaf_id, - uint8_t *sig, size_t sig_len) + uint64_t tree_id, uint32_t leaf_id, WPACKET *sig_wpkt) { SLH_ADRS_FUNC_DECLARE(ctx, adrsf); SLH_ADRS_DECLARE(adrs); @@ -40,8 +38,8 @@ int ossl_slh_ht_sign(SLH_DSA_CTX *ctx, uint32_t n = ctx->params->n; uint32_t d = ctx->params->d; uint32_t hm = ctx->params->hm; - uint8_t *psig = sig; - size_t xmss_sig_len = SLH_XMSS_SIG_LEN(n, hm); + uint8_t *psig; + PACKET rpkt, *xmss_sig_rpkt = &rpkt; mask = (1 << hm) - 1; /* A mod 2^h = A & ((2^h - 1))) */ @@ -49,21 +47,23 @@ int ossl_slh_ht_sign(SLH_DSA_CTX *ctx, memcpy(root, msg, n); for (layer = 0; layer < d; ++layer) { + /* type = SLH_ADRS_TYPE_WOTS_HASH */ adrsf->set_layer_address(adrs, layer); adrsf->set_tree_address(adrs, tree_id); + psig = WPACKET_get_curr(sig_wpkt); if (!ossl_slh_xmss_sign(ctx, root, sk_seed, leaf_id, pk_seed, adrs, - psig, xmss_sig_len)) + sig_wpkt)) + return 0; + if (!PACKET_buf_init(xmss_sig_rpkt, psig, WPACKET_get_curr(sig_wpkt) - psig)) return 0; if (layer < d - 1) { - if (!ossl_slh_xmss_pk_from_sig(ctx, leaf_id, psig, root, - pk_seed, adrs, root)) + if (!ossl_slh_xmss_pk_from_sig(ctx, leaf_id, xmss_sig_rpkt, root, + pk_seed, adrs, root, sizeof(root))) return 0; } - psig += xmss_sig_len; leaf_id = tree_id & mask; tree_id >>= hm; } - assert((size_t)(psig - sig) == sig_len); return 1; } @@ -81,21 +81,19 @@ int ossl_slh_ht_sign(SLH_DSA_CTX *ctx, * * @returns 1 if the computed XMSS public key matches pk_root, or 0 otherwise. */ -int ossl_slh_ht_verify(SLH_DSA_CTX *ctx, const uint8_t *msg, const uint8_t *sig, +int ossl_slh_ht_verify(SLH_DSA_CTX *ctx, const uint8_t *msg, PACKET *sig_pkt, const uint8_t *pk_seed, uint64_t tree_id, uint32_t leaf_id, const uint8_t *pk_root) { SLH_ADRS_FUNC_DECLARE(ctx, adrsf); SLH_ADRS_DECLARE(adrs); uint8_t node[SLH_MAX_N]; - uint32_t layer, len, mask, d, n, tree_height; const SLH_DSA_PARAMS *params = ctx->params; - - tree_height = params->hm; - n = params->n; - d = params->d; - len = SLH_XMSS_SIG_LEN(n, tree_height); - mask = (1 << tree_height) - 1; + uint32_t tree_height = params->hm; + uint32_t n = params->n; + uint32_t d = params->d; + uint32_t mask = (1 << tree_height) - 1; + uint32_t layer; adrsf->zero(adrs); memcpy(node, msg, n); @@ -103,10 +101,9 @@ int ossl_slh_ht_verify(SLH_DSA_CTX *ctx, const uint8_t *msg, const uint8_t *sig, for (layer = 0; layer < d; ++layer) { adrsf->set_layer_address(adrs, layer); adrsf->set_tree_address(adrs, tree_id); - if (!ossl_slh_xmss_pk_from_sig(ctx, leaf_id, sig, node, - pk_seed, adrs, node)) + if (!ossl_slh_xmss_pk_from_sig(ctx, leaf_id, sig_pkt, node, + pk_seed, adrs, node, sizeof(node))) return 0; - sig += len; leaf_id = tree_id & mask; tree_id >>= tree_height; } diff --git a/crypto/slh_dsa/slh_wots.c b/crypto/slh_dsa/slh_wots.c index a0b2766745c..7f56dd45754 100644 --- a/crypto/slh_dsa/slh_wots.c +++ b/crypto/slh_dsa/slh_wots.c @@ -82,32 +82,43 @@ static ossl_inline void compute_checksum_nibbles(const uint8_t *in, size_t in_le * * @param ctx Contains SLH_DSA algorithm functions and constants. * @param in An input string of |n| bytes - * @param n The size of |in| and |pk_seed|_ * @param start_index The chaining start index * @param steps The number of iterations starting from |start_index| * Note |start_index| + |steps| < w * (where w = 16 indicates the length of the hash chains) + * @param pk_seed A public key seed (which is added to the hash) * @param adrs An ADRS object which has a type of WOTS_HASH, and has a layer * address, tree address, key pair address and chain address - * @param pk_seed A public key seed (which is added to the hash) + * @params wpkt A WPACKET object to write the hash chain to (n bytes are written) * @returns 1 on success, or 0 on error. */ static int slh_wots_chain(SLH_DSA_CTX *ctx, const uint8_t *in, uint8_t start_index, uint8_t steps, - const uint8_t *pk_seed, uint8_t *adrs, uint8_t *out) + const uint8_t *pk_seed, uint8_t *adrs, WPACKET *wpkt) { SLH_HASH_FUNC_DECLARE(ctx, hashf, hctx); SLH_ADRS_FUNC_DECLARE(ctx, adrsf); SLH_HASH_FN_DECLARE(hashf, F); SLH_ADRS_FN_DECLARE(adrsf, set_hash_address); - size_t j, end_index = start_index + steps; + size_t j = start_index, end_index; size_t n = ctx->params->n; + uint8_t *tmp; /* Pointer into the |wpkt| buffer */ + size_t tmp_len = n; + + if (steps == 0) + return WPACKET_memcpy(wpkt, in, n); + + if (!WPACKET_allocate_bytes(wpkt, tmp_len, &tmp)) + return 0; - memcpy(out, in, n); + set_hash_address(adrs, j++); + if (!F(hctx, pk_seed, adrs, in, n, tmp, tmp_len)) + return 0; - for (j = start_index; j < end_index; ++j) { + end_index = start_index + steps; + for (; j < end_index; ++j) { set_hash_address(adrs, j); - if (!F(hctx, pk_seed, adrs, out, n, out)) + if (!F(hctx, pk_seed, adrs, tmp, n, tmp, tmp_len)) return 0; } return 1; @@ -123,11 +134,12 @@ static int slh_wots_chain(SLH_DSA_CTX *ctx, const uint8_t *in, * @param adrs An ADRS object containing the layer address, tree address and * keypair address of the WOTS+ public key to generate. * @param pk_out The generated public key of size |n| + * @param pk_out_len The maximum size of |pk_out| * @returns 1 on success, or 0 on error. */ int ossl_slh_wots_pk_gen(SLH_DSA_CTX *ctx, const uint8_t *sk_seed, const uint8_t *pk_seed, - SLH_ADRS adrs, uint8_t *pk_out) + SLH_ADRS adrs, uint8_t *pk_out, size_t pk_out_len) { int ret = 0; SLH_HASH_FUNC_DECLARE(ctx, hashf, hctx); @@ -136,33 +148,38 @@ int ossl_slh_wots_pk_gen(SLH_DSA_CTX *ctx, SLH_ADRS_FN_DECLARE(adrsf, set_chain_address); SLH_ADRS_DECLARE(sk_adrs); SLH_ADRS_DECLARE(wots_pk_adrs); - size_t i, len = 0; size_t n = ctx->params->n; - uint8_t tmp[SLH_WOTS_LEN_MAX * SLH_MAX_N], *ptmp = tmp; - uint8_t sk[32]; + size_t i, len = SLH_WOTS_LEN(n); /*2 * n + 3 */ + uint8_t sk[SLH_MAX_N]; + uint8_t tmp[SLH_WOTS_LEN_MAX * SLH_MAX_N]; + WPACKET pkt, *tmp_wpkt = &pkt; /* Points to the |tmp| buffer */ + size_t tmp_len = 0; + if (!WPACKET_init_static_len(tmp_wpkt, tmp, sizeof(tmp), 0)) + return 0; adrsf->copy(sk_adrs, adrs); adrsf->set_type_and_clear(sk_adrs, SLH_ADRS_TYPE_WOTS_PRF); adrsf->copy_keypair_address(sk_adrs, adrs); - len = SLH_WOTS_LEN(n); /* See Section 5 intro */ - for (i = 0; i < len; ++i) { + for (i = 0; i < len; ++i) { /* len = 2n + 3 */ set_chain_address(sk_adrs, i); - if (!PRF(hctx, pk_seed, sk_seed, sk_adrs, sk)) + if (!PRF(hctx, pk_seed, sk_seed, sk_adrs, sk, sizeof(sk))) goto end; set_chain_address(adrs, i); - if (!slh_wots_chain(ctx, sk, 0, NIBBLE_MASK, pk_seed, adrs, ptmp)) + if (!slh_wots_chain(ctx, sk, 0, NIBBLE_MASK, pk_seed, adrs, tmp_wpkt)) goto end; - ptmp += n; } - len = ptmp - tmp; /* should be n * (2 * n + 3) */ + if (!WPACKET_get_total_written(tmp_wpkt, &tmp_len)) /* should be n * (2 * n + 3) */ + goto end; adrsf->copy(wots_pk_adrs, adrs); adrsf->set_type_and_clear(wots_pk_adrs, SLH_ADRS_TYPE_WOTS_PK); adrsf->copy_keypair_address(wots_pk_adrs, adrs); - ret = hashf->T(hctx, pk_seed, wots_pk_adrs, tmp, len, pk_out); + ret = hashf->T(hctx, pk_seed, wots_pk_adrs, tmp, tmp_len, + pk_out, pk_out_len); end: + WPACKET_finish(tmp_wpkt); OPENSSL_cleanse(tmp, sizeof(tmp)); OPENSSL_cleanse(sk, n); return ret; @@ -172,6 +189,8 @@ end: * @brief WOTS+ Signature generation * See FIPS 205 Section 5.2 Algorithm 7 * + * The returned signature size is len * |n| bytes (where len = 2 * |n| + 3). + * * @param ctx Contains SLH_DSA algorithm functions and constants. * @param msg An input message of size |n| bytes. * The message is either an XMSS or FORS public key @@ -179,14 +198,12 @@ end: * @param pk_seed The public key seed of size |n| bytes * @param adrs An address containing the layer address, tree address and key * pair address. The size is either 32 or 22 bytes. - * @param sig The returned signature. - * @param sig_len The size of |sig| which should be len * |n| bytes. - * (where len = 2 * |n| + 3) + * @param sig_wpkt A WPACKET object to write the signature to. * @returns 1 on success, or 0 on error. */ int ossl_slh_wots_sign(SLH_DSA_CTX *ctx, const uint8_t *msg, const uint8_t *sk_seed, const uint8_t *pk_seed, - SLH_ADRS adrs, uint8_t *sig, size_t sig_len) + SLH_ADRS adrs, WPACKET *sig_wpkt) { int ret = 0; SLH_HASH_FUNC_DECLARE(ctx, hashf, hctx); @@ -198,7 +215,6 @@ int ossl_slh_wots_sign(SLH_DSA_CTX *ctx, const uint8_t *msg, uint8_t sk[SLH_MAX_N]; size_t i, len1, len; size_t n = ctx->params->n; - uint8_t *psig = sig; len1 = SLH_WOTS_LEN1(n); /* 2 * n is for the message length in nibbles */ len = len1 + SLH_WOTS_LEN2; /* 2 * n + 3 (3 checksum nibbles) */ @@ -218,19 +234,16 @@ int ossl_slh_wots_sign(SLH_DSA_CTX *ctx, const uint8_t *msg, for (i = 0; i < len; ++i) { set_chain_address(sk_adrs, i); /* compute chain i secret */ - if (!PRF(hctx, pk_seed, sk_seed, sk_adrs, sk)) + if (!PRF(hctx, pk_seed, sk_seed, sk_adrs, sk, sizeof(sk))) goto err; set_chain_address(adrs, i); /* compute chain i signature */ if (!slh_wots_chain(ctx, sk, 0, msg_and_csum_nibbles[i], - pk_seed, adrs, psig)) + pk_seed, adrs, sig_wpkt)) goto err; - psig += n; } - assert(sig_len == (size_t)(psig - sig)); ret = 1; err: - OPENSSL_cleanse(sk, n); return ret; } @@ -238,30 +251,40 @@ err: * @brief Compute a candidate WOTS+ public key from a message and signature * See FIPS 205 Section 5.3 Algorithm 8 * + * The size of the signature is len * |n| bytes (where len = 2 * |n| + 3). + * * @param ctx Contains SLH_DSA algorithm functions and constants. - * @param sig A WOTS+ signature of size len * |n| bytes. (where len = 2 * |n| + 3) + * @param sig_rpkt A PACKET object to read a WOTS+ signature from * @param msg A message of size |n| bytes. * @param pk_seed The public key seed of size |n|. * @param adrs An ADRS object containing the layer address, tree address and * key pair address that of the WOTS+ key used to sign the message. * @param pk_out The returned public key candidate of size |n| + * @param pk_out_len The maximum size of |pk_out| * @returns 1 on success, or 0 on error. */ int ossl_slh_wots_pk_from_sig(SLH_DSA_CTX *ctx, - const uint8_t *sig, const uint8_t *msg, + PACKET *sig_rpkt, const uint8_t *msg, const uint8_t *pk_seed, uint8_t *adrs, - uint8_t *pk_out) + uint8_t *pk_out, size_t pk_out_len) { + int ret = 0; SLH_HASH_FUNC_DECLARE(ctx, hashf, hctx); SLH_ADRS_FUNC_DECLARE(ctx, adrsf); SLH_ADRS_FN_DECLARE(adrsf, set_chain_address); SLH_ADRS_DECLARE(wots_pk_adrs); uint8_t msg_and_csum_nibbles[SLH_WOTS_LEN_MAX]; - uint8_t tmp[SLH_WOTS_LEN_MAX * SLH_MAX_N], *ptmp = tmp; size_t i, len1, len, n = ctx->params->n; + const uint8_t *sig_i; /* Pointer into |pkt_sig| buffer */ + uint8_t tmp[SLH_WOTS_LEN_MAX * SLH_MAX_N]; + WPACKET pkt, *tmp_pkt = &pkt; + size_t tmp_len = 0; len1 = SLH_WOTS_LEN1(n); - len = len1 + SLH_WOTS_LEN2; + len = len1 + SLH_WOTS_LEN2; /* 2n + 3 */ + + if (!WPACKET_init_static_len(tmp_pkt, tmp, sizeof(tmp), 0)) + return 0; slh_bytes_to_nibbles(msg, n, msg_and_csum_nibbles); compute_checksum_nibbles(msg_and_csum_nibbles, len1, msg_and_csum_nibbles + len1); @@ -269,16 +292,22 @@ int ossl_slh_wots_pk_from_sig(SLH_DSA_CTX *ctx, /* Compute the end nodes for each of the chains */ for (i = 0; i < len; ++i) { set_chain_address(adrs, i); - if (!slh_wots_chain(ctx, sig, msg_and_csum_nibbles[i], - NIBBLE_MASK - msg_and_csum_nibbles[i], - pk_seed, adrs, ptmp)) - return 0; - sig += n; - ptmp += n; + if (!PACKET_get_bytes(sig_rpkt, &sig_i, n) + || !slh_wots_chain(ctx, sig_i, msg_and_csum_nibbles[i], + NIBBLE_MASK - msg_and_csum_nibbles[i], + pk_seed, adrs, tmp_pkt)) + goto err; } /* compress the computed public key value */ adrsf->copy(wots_pk_adrs, adrs); adrsf->set_type_and_clear(wots_pk_adrs, SLH_ADRS_TYPE_WOTS_PK); adrsf->copy_keypair_address(wots_pk_adrs, adrs); - return hashf->T(hctx, pk_seed, wots_pk_adrs, tmp, ptmp - tmp, pk_out); + if (!WPACKET_get_total_written(tmp_pkt, &tmp_len)) + goto err; + ret = hashf->T(hctx, pk_seed, wots_pk_adrs, tmp, tmp_len, + pk_out, pk_out_len); + err: + if (!WPACKET_finish(tmp_pkt)) + ret = 0; + return ret; } diff --git a/crypto/slh_dsa/slh_xmss.c b/crypto/slh_dsa/slh_xmss.c index 08f112f0d2f..eb9125aec33 100644 --- a/crypto/slh_dsa/slh_xmss.c +++ b/crypto/slh_dsa/slh_xmss.c @@ -18,22 +18,24 @@ * the hash of each parent using 2 child nodes. * * @param ctx Contains SLH_DSA algorithm functions and constants. - * @param sk_seed A private key seed of size |n| + * @param sk_seed A SLH-DSA private key seed of size |n| * @param nodeid The index of the target node being computed * (which must be < 2^(hm - height) * @param h The height within the tree of the node being computed. * (which must be <= hm) (hm is one of 3, 4, 8 or 9) * At height=0 There are 2^hm leaf nodes, * and the root node is at height = hm) - * @param pk_seed A public key seed of size |n| + * @param pk_seed A SLH-DSA public key seed of size |n| * @param adrs An ADRS object containing the layer address and tree address set * to the XMSS tree within which the XMSS tree is being computed. * @param pk_out The generated public key of size |n| + * @param pk_out_len The maximum size of |pk_out| * @returns 1 on success, or 0 on error. */ int ossl_slh_xmss_node(SLH_DSA_CTX *ctx, const uint8_t *sk_seed, uint32_t node_id, uint32_t h, - const uint8_t *pk_seed, SLH_ADRS adrs, uint8_t *pk_out) + const uint8_t *pk_seed, SLH_ADRS adrs, + uint8_t *pk_out, size_t pk_out_len) { SLH_ADRS_FUNC_DECLARE(ctx, adrsf); @@ -41,20 +43,22 @@ int ossl_slh_xmss_node(SLH_DSA_CTX *ctx, const uint8_t *sk_seed, /* For leaf nodes generate the public key */ adrsf->set_type_and_clear(adrs, SLH_ADRS_TYPE_WOTS_HASH); adrsf->set_keypair_address(adrs, node_id); - if (!ossl_slh_wots_pk_gen(ctx, sk_seed, pk_seed, adrs, pk_out)) + if (!ossl_slh_wots_pk_gen(ctx, sk_seed, pk_seed, adrs, + pk_out, pk_out_len)) return 0; } else { uint8_t lnode[SLH_MAX_N], rnode[SLH_MAX_N]; if (!ossl_slh_xmss_node(ctx, sk_seed, 2 * node_id, h - 1, pk_seed, adrs, - lnode) + lnode, sizeof(lnode)) || !ossl_slh_xmss_node(ctx, sk_seed, 2 * node_id + 1, h - 1, - pk_seed, adrs, rnode)) + pk_seed, adrs, rnode, sizeof(rnode))) return 0; adrsf->set_type_and_clear(adrs, SLH_ADRS_TYPE_TREE); adrsf->set_tree_height(adrs, h); adrsf->set_tree_index(adrs, node_id); - if (!ctx->hash_func->H(&ctx->hash_ctx, pk_seed, adrs, lnode, rnode, pk_out)) + if (!ctx->hash_func->H(&ctx->hash_ctx, pk_seed, adrs, lnode, rnode, + pk_out, pk_out_len)) return 0; } return 1; @@ -64,6 +68,10 @@ int ossl_slh_xmss_node(SLH_DSA_CTX *ctx, const uint8_t *sk_seed, * @brief Generate an XMSS signature using a message and key. * See FIPS 205 Section 6.2 Algorithm 10 * + * The generated signature consists of: + * - A WOTS+ signature of size (2 * n + 3) * n + * - An array of authentication paths of size (XMSS tree_height) * n. + * * @param ctx Contains SLH_DSA algorithm functions and constants. * @param msg A message of size |n| bytes to sign * @param sk_seed A private key seed of size |n| @@ -71,56 +79,66 @@ int ossl_slh_xmss_node(SLH_DSA_CTX *ctx, const uint8_t *sk_seed, * @param pk_seed A public key seed f size |n| * @param adrs An ADRS object containing the layer address and tree address set * to the XMSS key being used to sign the message. - * @param sig The generated XMSS signature. - * @param sig_len The size of |sig|. which consists of a WOTS+ - * signature of size [2 * n + 3][n] followed by an authentication - * path of size [tree_height[n]. + * @param sig_wpkt A WPACKET object to write the generated XMSS signature to. * @returns 1 on success, or 0 on error. */ int ossl_slh_xmss_sign(SLH_DSA_CTX *ctx, const uint8_t *msg, const uint8_t *sk_seed, uint32_t node_id, - const uint8_t *pk_seed, SLH_ADRS adrs, - uint8_t *sig, size_t sig_len) + const uint8_t *pk_seed, SLH_ADRS adrs, WPACKET *sig_wpkt) { SLH_ADRS_FUNC_DECLARE(ctx, adrsf); - uint32_t h, id = node_id; + SLH_ADRS_DECLARE(tmp_adrs); size_t n = ctx->params->n; - uint32_t hm = ctx->params->hm; - size_t wots_sig_len = n * SLH_WOTS_LEN(n); - uint8_t *auth_path = sig + wots_sig_len; + uint32_t h, hm = ctx->params->hm; + uint32_t id = node_id; + uint8_t *auth_path; /* Pointer to a buffer offset inside |sig_wpkt| */ + size_t auth_path_len = n; + /* + * This code reverses the order of the FIPS 205 code so that it does the + * sign first. This simplifies the WPACKET writing. + */ + adrsf->copy(tmp_adrs, adrs); + adrsf->set_type_and_clear(adrs, SLH_ADRS_TYPE_WOTS_HASH); + adrsf->set_keypair_address(adrs, node_id); + if (!ossl_slh_wots_sign(ctx, msg, sk_seed, pk_seed, adrs, sig_wpkt)) + return 0; + + adrsf->copy(adrs, tmp_adrs); for (h = 0; h < hm; ++h) { - if (!ossl_slh_xmss_node(ctx, sk_seed, id ^ 1, h, pk_seed, adrs, auth_path)) + if (!WPACKET_allocate_bytes(sig_wpkt, auth_path_len, &auth_path) + || !ossl_slh_xmss_node(ctx, sk_seed, id ^ 1, h, pk_seed, adrs, + auth_path, auth_path_len)) return 0; id >>= 1; - auth_path += n; } - adrsf->set_type_and_clear(adrs, SLH_ADRS_TYPE_WOTS_HASH); - adrsf->set_keypair_address(adrs, node_id); - return ossl_slh_wots_sign(ctx, msg, sk_seed, pk_seed, adrs, sig, wots_sig_len); + return 1; } /** * @brief Compute a candidate XMSS public key from a message and XMSS signature * See FIPS 205 Section 6.3 Algorithm 11 * + * * The signature consists of: + * - A WOTS+ signature of size (2 * n + 3) * n + * - An array of authentication paths of size (XMSS tree height) * n. + * * @param ctx Contains SLH_DSA algorithm functions and constants. * @param node_id Must be set to the |node_id| used in xmss_sign(). - * @param sig A XMSS signature which consists of a WOTS+ signature of - * [2 * n + 3][n] bytes followed by an authentication path of - * [hm][n] bytes (where hm is the height of the XMSS tree). + * @param sig_rpkt A Packet to read a XMSS signature from. * @param msg A message of size |n| bytes * @param sk_seed A private key seed of size |n| * @param pk_seed A public key seed of size |n| * @param adrs An ADRS object containing a layer address and tree address of an * XMSS key used for signing the message. * @param pk_out The returned candidate XMSS public key of size |n|. + * @param pk_out_len The maximum size of |pk_out|. * @returns 1 on success, or 0 on error. */ int ossl_slh_xmss_pk_from_sig(SLH_DSA_CTX *ctx, uint32_t node_id, - const uint8_t *sig, const uint8_t *msg, + PACKET *sig_rpkt, const uint8_t *msg, const uint8_t *pk_seed, SLH_ADRS adrs, - uint8_t *pk_out) + uint8_t *pk_out, size_t pk_out_len) { SLH_HASH_FUNC_DECLARE(ctx, hashf, hctx); SLH_HASH_FN_DECLARE(hashf, H); @@ -130,31 +148,32 @@ int ossl_slh_xmss_pk_from_sig(SLH_DSA_CTX *ctx, uint32_t node_id, uint32_t k; size_t n = ctx->params->n; uint32_t hm = ctx->params->hm; - size_t wots_sig_len = n * SLH_WOTS_LEN(n); - const uint8_t *auth_path = sig + wots_sig_len; uint8_t *node = pk_out; + const uint8_t *auth_path; /* Pointer to buffer offset in |pkt_sig| */ adrsf->set_type_and_clear(adrs, SLH_ADRS_TYPE_WOTS_HASH); adrsf->set_keypair_address(adrs, node_id); - if (!ossl_slh_wots_pk_from_sig(ctx, sig, msg, pk_seed, adrs, node)) + if (!ossl_slh_wots_pk_from_sig(ctx, sig_rpkt, msg, pk_seed, adrs, + node, pk_out_len)) return 0; adrsf->set_type_and_clear(adrs, SLH_ADRS_TYPE_TREE); for (k = 0; k < hm; ++k) { + if (!PACKET_get_bytes(sig_rpkt, &auth_path, n)) + return 0; set_tree_height(adrs, k + 1); if ((node_id & 1) == 0) { /* even */ node_id >>= 1; set_tree_index(adrs, node_id); - if (!H(hctx, pk_seed, adrs, node, auth_path, node)) + if (!H(hctx, pk_seed, adrs, node, auth_path, node, pk_out_len)) return 0; } else { /* odd */ node_id = (node_id - 1) >> 1; set_tree_index(adrs, node_id); - if (!H(hctx, pk_seed, adrs, auth_path, node, node)) + if (!H(hctx, pk_seed, adrs, auth_path, node, node, pk_out_len)) return 0; } - auth_path += n; } return 1; } diff --git a/doc/designs/slh-dsa.md b/doc/designs/slh-dsa.md index b1c5d014f74..4cca1f2b572 100644 --- a/doc/designs/slh-dsa.md +++ b/doc/designs/slh-dsa.md @@ -101,11 +101,12 @@ EVP_PKEY_verify_message_init(), EVP_PKEY_verify(). Buffers ------- -Many functions need to pass around key elements and return signature buffers of -various sizes which are often updated in loops in parts, all of these sizes -are known quantities. Currently there is no attempt to use wpacket to pass -around these sizes. asserts are currently done by the child functions to check -that the expected size does not exceed the size passed in by the parent. +There are many functions pass buffers of size |n| Where n is one of 16,24,32 +depending on the algorithm name. These are used for key elements and hashes, so +PACKETS are not used for these. + +Where it makes sense to, WPACKET is used for output (such as signature generation) +and PACKET for reading signature data. Constant Time Considerations ---------------------------- diff --git a/test/slh_dsa_test.c b/test/slh_dsa_test.c index 1926b637a5a..ef026e2d350 100644 --- a/test/slh_dsa_test.c +++ b/test/slh_dsa_test.c @@ -243,7 +243,6 @@ static int slh_dsa_sign_verify_test(int tst_id) || !TEST_int_eq(EVP_PKEY_sign(sctx, psig, &psig_len, td->msg, td->msg_len), 1)) goto err; - if (!TEST_int_eq(EVP_Q_digest(lib_ctx, "SHA256", NULL, psig, psig_len, digest, &digest_len), 1)) goto err;