]> git.ipfire.org Git - thirdparty/openssl.git/commitdiff
Update SLH-DSA code to use PACKET and WPACKET.
authorslontis <shane.lontis@oracle.com>
Tue, 12 Nov 2024 07:35:10 +0000 (18:35 +1100)
committerTomas Mraz <tomas@openssl.org>
Tue, 18 Feb 2025 09:17:29 +0000 (10:17 +0100)
Reviewed-by: Paul Dale <ppzgs1@gmail.com>
Reviewed-by: Viktor Dukhovni <viktor@openssl.org>
Reviewed-by: Tim Hudson <tjh@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/25882)

crypto/slh_dsa/slh_dsa.c
crypto/slh_dsa/slh_dsa_key.c
crypto/slh_dsa/slh_dsa_local.h
crypto/slh_dsa/slh_fors.c
crypto/slh_dsa/slh_hash.c
crypto/slh_dsa/slh_hash.h
crypto/slh_dsa/slh_hypertree.c
crypto/slh_dsa/slh_wots.c
crypto/slh_dsa/slh_xmss.c
doc/designs/slh-dsa.md
test/slh_dsa_test.c

index 10b0b1f66b011928ab43764c078e4fc0a95274fd..0fdef850349e9a2f1dade36ed895c626d69d8c5d 100644 (file)
 #include "slh_dsa_local.h"
 #include "slh_dsa_key.h"
 
-#define SLH_MAX_M 49
+#define SLH_MAX_M 49 /* See slh_params.c */
+/* The size of md is (21..40 bytes) - since a is in bits round up to nearest byte */
+#define MD_LEN(params) (((params)->k * (params)->a + 7) >> 3)
 
-/* (n + SLH_SIG_FORS_LEN(k, a, n) + SLH_SIG_HT_LEN(n, hm, d)) */
-#define SLH_SIG_RANDOM_LEN(n)      (n)
-#define SLH_SIG_FORS_LEN(k, a, n)  (n) * ((k) * (1 + (a)))
-#define SLH_SIG_HT_LEN(h, d, n)    (n) * ((h) + (d) * SLH_WOTS_LEN(n))
-
-static void get_tree_ids(const uint8_t *digest, const SLH_DSA_PARAMS *params,
-                         uint64_t *tree_id, uint32_t *leaf_id);
+static int get_tree_ids(PACKET *pkt, const SLH_DSA_PARAMS *params,
+                        uint64_t *tree_id, uint32_t *leaf_id);
 
 /**
  * @brief SLH-DSA Signature generation
@@ -46,23 +43,23 @@ static int slh_sign_internal(SLH_DSA_CTX *ctx, const SLH_DSA_KEY *priv,
                              uint8_t *sig, size_t *sig_len, size_t sig_size,
                              const uint8_t *opt_rand)
 {
+    int ret = 0;
     const SLH_DSA_PARAMS *params = ctx->params;
-    uint32_t n = params->n;
-    size_t r_len = n;
-    size_t sig_fors_len = SLH_SIG_FORS_LEN(params->k, params->a, n);
-    size_t sig_ht_len = SLH_SIG_HT_LEN(params->h, params->d, n);
-    size_t sig_len_expected = r_len + sig_fors_len + sig_ht_len;
+    size_t sig_len_expected = params->sig_len;
     SLH_HASH_FUNC_DECLARE(ctx, hashf, hctx);
     SLH_ADRS_FUNC_DECLARE(ctx, adrsf);
     SLH_ADRS_DECLARE(adrs);
+    uint8_t m_digest[SLH_MAX_M];
+    const uint8_t *md; /* The first md_len bytes of m_digest */
+    size_t md_len = MD_LEN(params); /* The size of the digest |md| */
+    /* Points to |m_digest| buffer, it is also reused to point to |sig_fors| */
+    PACKET r_packet, *rpkt = &r_packet;
+    uint8_t *r, *sig_fors; /* Pointers into buffer inside |wpkt| */
+    WPACKET w_packet, *wpkt = &w_packet; /* Points to output |sig| buffer */
+    const uint8_t *pk_seed, *sk_seed; /* pointers to elements within |priv| */
+    uint8_t pk_fors[SLH_MAX_N];
     uint64_t tree_id;
     uint32_t leaf_id;
-    uint8_t pk_fors[SLH_MAX_N];
-    uint8_t m_digest[SLH_MAX_M];
-    uint8_t *r = sig;
-    uint8_t *sig_fors = r + r_len;
-    uint8_t *sig_ht = sig_fors + sig_fors_len;
-    const uint8_t *md, *pk_seed, *sk_seed;
 
     if (sig_len != NULL)
         *sig_len = sig_len_expected;
@@ -76,6 +73,11 @@ static int slh_sign_internal(SLH_DSA_CTX *ctx, const SLH_DSA_KEY *priv,
     if (priv->has_priv == 0)
         return 0;
 
+    if (!WPACKET_init_static_len(wpkt, sig, sig_len_expected, 0))
+        return 0;
+    if (!PACKET_buf_init(rpkt, m_digest, params->m))
+        return 0;
+
     pk_seed = SLH_DSA_PK_SEED(priv);
     sk_seed = SLH_DSA_SK_SEED(priv);
 
@@ -83,26 +85,38 @@ static int slh_sign_internal(SLH_DSA_CTX *ctx, const SLH_DSA_KEY *priv,
         opt_rand = pk_seed;
 
     adrsf->zero(adrs);
-    /* calculate Randomness value r, and output to the signature */
-    if (!hashf->PRF_MSG(hctx, SLH_DSA_SK_PRF(priv), opt_rand, msg, msg_len, r)
+    /* calculate Randomness value r, and output to the SLH-DSA signature */
+    r = WPACKET_get_curr(wpkt);
+    if (!hashf->PRF_MSG(hctx, SLH_DSA_SK_PRF(priv), opt_rand, msg, msg_len, wpkt)
             /* generate a digest of size |params->m| bytes where m is (30..49) */
             || !hashf->H_MSG(hctx, r, pk_seed, SLH_DSA_PK_ROOT(priv), msg, msg_len,
-                             m_digest))
-        return 0;
-    /* Grab selected bytes from the digest to select tree and leaf id's */
-    get_tree_ids(m_digest, params, &tree_id, &leaf_id);
+                             m_digest, sizeof(m_digest))
+            /* Grab the first md_len bytes of m_digest to use in fors_sign() */
+            || !PACKET_get_bytes(rpkt, &md, md_len)
+            /* Grab remaining bytes from m_digest to select tree and leaf id's */
+            || !get_tree_ids(rpkt, params, &tree_id, &leaf_id))
+        goto err;
 
     adrsf->set_tree_address(adrs, tree_id);
     adrsf->set_type_and_clear(adrs, SLH_ADRS_TYPE_FORS_TREE);
     adrsf->set_keypair_address(adrs, leaf_id);
 
-    /* generate the FORS signature and append it to the signature */
-    md = m_digest;
-    return ossl_slh_fors_sign(ctx, md, sk_seed, pk_seed, adrs, sig_fors, sig_fors_len)
-        /* Calculate the FORS public key */
-        && ossl_slh_fors_pk_from_sig(ctx, sig_fors, md, pk_seed, adrs, pk_fors)
+    sig_fors = WPACKET_get_curr(wpkt);
+    /* generate the FORS signature and append it to the SLH-DSA signature */
+    ret = ossl_slh_fors_sign(ctx, md, sk_seed, pk_seed, adrs, wpkt)
+        /* Reuse rpkt to point to the FORS signature that was just generated */
+        && PACKET_buf_init(rpkt, sig_fors, WPACKET_get_curr(wpkt) - sig_fors)
+        /* Calculate the FORS public key using the generated FORS signature */
+        && ossl_slh_fors_pk_from_sig(ctx, rpkt, md, pk_seed, adrs,
+                                     pk_fors, sizeof(pk_fors))
+        /* Generate ht signature and append to the SLH-DSA signature */
         && ossl_slh_ht_sign(ctx, pk_fors, sk_seed, pk_seed, tree_id, leaf_id,
-                            sig_ht, sig_ht_len);
+                            wpkt);
+    ret = 1;
+ err:
+    if (!WPACKET_finish(wpkt))
+        ret = 0;
+    return ret;
 }
 
 /**
@@ -129,43 +143,57 @@ static int slh_verify_internal(SLH_DSA_CTX *ctx, const SLH_DSA_KEY *pub,
     SLH_HASH_FUNC_DECLARE(ctx, hashf, hctx);
     SLH_ADRS_FUNC_DECLARE(ctx, adrsf);
     SLH_ADRS_DECLARE(adrs);
-    uint8_t mdigest[SLH_MAX_M];
+    const SLH_DSA_PARAMS *params = ctx->params;
+    uint32_t n = params->n;
+    const uint8_t *pk_seed, *pk_root; /* Pointers to elements in |pub| */
+    PACKET pkt, *sig_rpkt = &pkt; /* Points to the |sig| buffer */
+    uint8_t m_digest[SLH_MAX_M];
+    const uint8_t *md; /* This is a pointer into the buffer in m_digest_rpkt */
+    size_t md_len = MD_LEN(params); /* 21..40 bytes */
+    PACKET pkt2, *m_digest_rpkt = &pkt2; /* Points to m_digest buffer */
+    const uint8_t *r; /* Pointer to |sig_rpkt| buffer */
     uint8_t pk_fors[SLH_MAX_N];
     uint64_t tree_id;
     uint32_t leaf_id;
-    const SLH_DSA_PARAMS *params = ctx->params;
-    uint32_t n = params->n;
-    size_t r_len = SLH_SIG_RANDOM_LEN(n);
-    size_t sig_fors_len = SLH_SIG_FORS_LEN(params->k, params->a, n);
-    size_t sig_ht_len = SLH_SIG_HT_LEN(params->h, params->d, n);
-    const uint8_t *r, *sig_fors, *sig_ht, *md, *pk_seed, *pk_root;
 
-    if (sig_len != (r_len + sig_fors_len + sig_ht_len))
-        return 0;
     /* Exit if public key is not set */
     if (pub->key_len == 0)
         return 0;
 
-    adrsf->zero(adrs);
+    /* Exit if signature is invalid size */
+    if (sig_len != ctx->params->sig_len
+            || !PACKET_buf_init(sig_rpkt, sig, sig_len))
+        return 0;
+    if (!PACKET_get_bytes(sig_rpkt, &r, n))
+        return 0;
 
-    r = sig;
-    sig_fors = r + r_len;
-    sig_ht = sig_fors + sig_fors_len;
+    adrsf->zero(adrs);
 
     pk_seed = SLH_DSA_PK_SEED(pub);
     pk_root = SLH_DSA_PK_ROOT(pub);
 
-    if (!hashf->H_MSG(hctx, r, pk_seed, pk_root, msg, msg_len, mdigest))
+    if (!hashf->H_MSG(hctx, r, pk_seed, pk_root, msg, msg_len,
+                      m_digest, sizeof(m_digest)))
+        return 0;
+
+    /*
+     * Get md (the first md_len bytes of m_digest to use in
+     * ossl_slh_fors_pk_from_sig(), and then retrieve the tree id and leaf id
+     * from the remaining bytes in m_digest.
+     */
+    if (!PACKET_buf_init(m_digest_rpkt, m_digest, sizeof(m_digest))
+            || !PACKET_get_bytes(m_digest_rpkt, &md, md_len)
+            || !get_tree_ids(m_digest_rpkt, params, &tree_id, &leaf_id))
         return 0;
-    md = mdigest;
-    get_tree_ids(mdigest, params, &tree_id, &leaf_id);
 
     adrsf->set_tree_address(adrs, tree_id);
     adrsf->set_type_and_clear(adrs, SLH_ADRS_TYPE_FORS_TREE);
     adrsf->set_keypair_address(adrs, leaf_id);
-    return ossl_slh_fors_pk_from_sig(ctx, sig_fors, md, pk_seed, adrs, pk_fors)
-        && ossl_slh_ht_verify(ctx, pk_fors, sig_ht, pk_seed,
-                              tree_id, leaf_id, pk_root);
+    return ossl_slh_fors_pk_from_sig(ctx, sig_rpkt, md, pk_seed, adrs,
+                                     pk_fors, sizeof(pk_fors))
+        && ossl_slh_ht_verify(ctx, pk_fors, sig_rpkt, pk_seed,
+                              tree_id, leaf_id, pk_root)
+        && PACKET_remaining(sig_rpkt) == 0;
 }
 
 /**
@@ -288,21 +316,20 @@ static uint64_t bytes_to_u64_be(const uint8_t *in, size_t in_len)
  * Converts digested bytes into a tree index, and leaf index within the tree.
  * The sizes are determined by the |params| parameter set.
  */
-static void get_tree_ids(const uint8_t *digest, const SLH_DSA_PARAMS *params,
-                         uint64_t *tree_id, uint32_t *leaf_id)
+static int get_tree_ids(PACKET *rpkt, const SLH_DSA_PARAMS *params,
+                        uint64_t *tree_id, uint32_t *leaf_id)
 {
     const uint8_t *tree_id_bytes, *leaf_id_bytes;
-    uint32_t md_len, tree_id_len, leaf_id_len;
+    uint32_t tree_id_len, leaf_id_len;
     uint64_t tree_id_mask, leaf_id_mask;
 
-    md_len = ((params->k * params->a + 7) >> 3); /* 21..40 bytes */
     tree_id_len = ((params->h - params->hm + 7) >> 3); /* 7 or 8 bytes */
     leaf_id_len = ((params->hm + 7) >> 3); /* 1 or 2 bytes */
 
-    tree_id_bytes = digest + md_len;
-    leaf_id_bytes = tree_id_bytes + tree_id_len;
+    if (!PACKET_get_bytes(rpkt, &tree_id_bytes, tree_id_len)
+            || !PACKET_get_bytes(rpkt, &leaf_id_bytes, leaf_id_len))
+        return 0;
 
-    assert((md_len + tree_id_len + leaf_id_len) == params->m);
     /*
      * In order to calculate A mod (2^X) where X is in the range of (54..64)
      * This is equivalent to A & (2^x - 1) which is just a sequence of X ones
@@ -315,4 +342,5 @@ static void get_tree_ids(const uint8_t *digest, const SLH_DSA_PARAMS *params,
     leaf_id_mask = (1 << params->hm) - 1; /* max value is 0x1FF when hm = 9 */
     *tree_id = bytes_to_u64_be(tree_id_bytes, tree_id_len) & tree_id_mask;
     *leaf_id = (uint32_t)(bytes_to_u64_be(leaf_id_bytes, leaf_id_len) & leaf_id_mask);
+    return 1;
 }
index e189c8e55e4fa6e8a0b1a5531e4d8bbf7db1ea0f..cf4f5c8787b3f06043d32a88c79738ea259a7e1c 100644 (file)
@@ -204,13 +204,12 @@ static int slh_dsa_compute_pk_root(SLH_DSA_CTX *ctx, SLH_DSA_KEY *out)
     SLH_ADRS_DECLARE(adrs);
     const SLH_DSA_PARAMS *params = out->params;
 
-    assert(params != NULL);
-
     adrsf->zero(adrs);
     adrsf->set_layer_address(adrs, params->d - 1);
     /* Generate the ROOT public key */
     return ossl_slh_xmss_node(ctx, SLH_DSA_SK_SEED(out), 0, params->hm,
-                              SLH_DSA_PK_SEED(out), adrs, SLH_DSA_PK_ROOT(out));
+                              SLH_DSA_PK_SEED(out), adrs,
+                              SLH_DSA_PK_ROOT(out), params->n);
 }
 
 /**
index 59c5fb605d58d92294c7103c2778020598d8f6be..8b4a40941d68eaaba2f62328127287b014615dea 100644 (file)
@@ -46,42 +46,43 @@ struct slh_dsa_ctx_st {
     SLH_HASH_CTX hash_ctx;
 };
 
-__owur int ossl_slh_wots_pk_gen(SLH_DSA_CTX *ctx,
-                                const uint8_t *sk_seed, const uint8_t *pk_seed,
-                                SLH_ADRS adrs, uint8_t *pk_out);
+__owur int ossl_slh_wots_pk_gen(SLH_DSA_CTX *ctx, const uint8_t *sk_seed,
+                                const uint8_t *pk_seed, SLH_ADRS adrs,
+                                uint8_t *pk_out, size_t pk_out_len);
 __owur int ossl_slh_wots_sign(SLH_DSA_CTX *ctx, const uint8_t *msg,
                              const uint8_t *sk_seed, const uint8_t *pk_seed,
-                             SLH_ADRS adrs, uint8_t *sig, size_t sig_len);
+                             SLH_ADRS adrs, WPACKET *sig_wpkt);
 __owur int ossl_slh_wots_pk_from_sig(SLH_DSA_CTX *ctx,
-                                     const uint8_t *sig, const uint8_t *msg,
-                                     const uint8_t *pk_seed, uint8_t *adrs,
-                                     uint8_t *pk_out);
+                                     PACKET *sig_rpkt, const uint8_t *msg,
+                                     const uint8_t *pk_seed, SLH_ADRS adrs,
+                                     uint8_t *pk_out, size_t pk_out_len);
 
 __owur int ossl_slh_xmss_node(SLH_DSA_CTX *ctx, const uint8_t *sk_seed,
                               uint32_t node_id, uint32_t height,
                               const uint8_t *pk_seed, SLH_ADRS adrs,
-                              uint8_t *pk_out);
+                              uint8_t *pk_out, size_t pk_out_len);
 __owur int ossl_slh_xmss_sign(SLH_DSA_CTX *ctx, const uint8_t *msg,
                               const uint8_t *sk_seed, uint32_t node_id,
                               const uint8_t *pk_seed, SLH_ADRS adrs,
-                              uint8_t *sig, size_t sig_len);
+                              WPACKET *sig_wpkt);
 __owur int ossl_slh_xmss_pk_from_sig(SLH_DSA_CTX *ctx, uint32_t node_id,
-                                     const uint8_t *sig, const uint8_t *msg,
+                                     PACKET *sig_rpkt, const uint8_t *msg,
                                      const uint8_t *pk_seed, SLH_ADRS adrs,
-                                     uint8_t *pk_out);
+                                     uint8_t *pk_out, size_t pk_out_len);
 
 __owur int ossl_slh_ht_sign(SLH_DSA_CTX *ctx, const uint8_t *msg,
                             const uint8_t *sk_seed, const uint8_t *pk_seed,
                             uint64_t tree_id, uint32_t leaf_id,
-                            uint8_t *sig_out, size_t sig_out_len);
+                            WPACKET *sig_wpkt);
 __owur int ossl_slh_ht_verify(SLH_DSA_CTX *ctx, const uint8_t *msg,
-                              const uint8_t *sig, const uint8_t *pk_seed,
+                              PACKET *sig_rpkt, const uint8_t *pk_seed,
                               uint64_t tree_id, uint32_t leaf_id,
                               const uint8_t *pk_root);
 
 __owur int ossl_slh_fors_sign(SLH_DSA_CTX *ctx, const uint8_t *md,
                               const uint8_t *sk_seed, const uint8_t *pk_seed,
-                              SLH_ADRS adrs, uint8_t *sig, size_t sig_len);
-__owur int ossl_slh_fors_pk_from_sig(SLH_DSA_CTX *ctx, const uint8_t *sig,
+                              SLH_ADRS adrs, WPACKET *sig_wpkt);
+__owur int ossl_slh_fors_pk_from_sig(SLH_DSA_CTX *ctx, PACKET *sig_rpkt,
                                      const uint8_t *md, const uint8_t *pk_seed,
-                                     SLH_ADRS adrs, uint8_t *pk_out);
+                                     SLH_ADRS adrs,
+                                     uint8_t *pk_out, size_t pk_out_len);
index 499a5ff0676914b6104114b578101248b3de69bc..e3f8655b1cbedb3b0c6a48c7fa3be76044d5b13a 100644 (file)
@@ -35,11 +35,12 @@ static void slh_base_2b(const uint8_t *in, uint32_t b, uint32_t *out, size_t out
  * @param id The index of the FORS secret value within the sets of FORS trees.
  *               (which must be < 2^(hm - height)
  * @param pk_out The generated FORS secret value of size |n|
+ * @param pk_out_len The maximum size of |pk_out|
  * @returns 1 on success, or 0 on error.
  */
 static int slh_fors_sk_gen(SLH_DSA_CTX *ctx, const uint8_t *sk_seed,
                            const uint8_t *pk_seed, SLH_ADRS adrs, uint32_t id,
-                           uint8_t *sk_out)
+                           uint8_t *pk_out, size_t pk_out_len)
 {
     SLH_ADRS_DECLARE(sk_adrs);
     SLH_ADRS_FUNC_DECLARE(ctx, adrsf);
@@ -48,7 +49,8 @@ static int slh_fors_sk_gen(SLH_DSA_CTX *ctx, const uint8_t *sk_seed,
     adrsf->set_type_and_clear(sk_adrs, SLH_ADRS_TYPE_FORS_PRF);
     adrsf->copy_keypair_address(sk_adrs, adrs);
     adrsf->set_tree_index(sk_adrs, id);
-    return ctx->hash_func->PRF(&ctx->hash_ctx, pk_seed, sk_seed, sk_adrs, sk_out);
+    return ctx->hash_func->PRF(&ctx->hash_ctx, pk_seed, sk_seed, sk_adrs,
+                               pk_out, pk_out_len);
 }
 
 /**
@@ -69,11 +71,12 @@ static int slh_fors_sk_gen(SLH_DSA_CTX *ctx, const uint8_t *sk_seed,
  * @param node_id The target node index
  * @param height The target node height
  * @param node The returned hash for a node of size|n|
+ * @param node_len The maximum size of |node|
  * @returns 1 on success, or 0 on error.
  */
 static int slh_fors_node(SLH_DSA_CTX *ctx, const uint8_t *sk_seed,
                          const uint8_t *pk_seed, SLH_ADRS adrs, uint32_t node_id,
-                         uint32_t height, uint8_t *node)
+                         uint32_t height, uint8_t *node, size_t node_len)
 {
     int ret = 0;
     SLH_ADRS_FUNC_DECLARE(ctx, adrsf);
@@ -81,22 +84,26 @@ static int slh_fors_node(SLH_DSA_CTX *ctx, const uint8_t *sk_seed,
     uint32_t n = ctx->params->n;
 
     if (height == 0) {
-        if (!slh_fors_sk_gen(ctx, sk_seed, pk_seed, adrs, node_id, sk))
+        /* Gets here for leaf nodes */
+        if (!slh_fors_sk_gen(ctx, sk_seed, pk_seed, adrs, node_id,
+                             sk, sizeof(sk)))
             return 0;
         adrsf->set_tree_height(adrs, 0);
         adrsf->set_tree_index(adrs, node_id);
-        ret = ctx->hash_func->F(&ctx->hash_ctx, pk_seed, adrs, sk, n, node);
+        ret = ctx->hash_func->F(&ctx->hash_ctx, pk_seed, adrs, sk, n,
+                                node, node_len);
         OPENSSL_cleanse(sk, n);
         return ret;
     } else {
         if (!slh_fors_node(ctx, sk_seed, pk_seed, adrs, 2 * node_id, height - 1,
-                           lnode)
+                           lnode, sizeof(rnode))
                 || !slh_fors_node(ctx, sk_seed, pk_seed, adrs, 2 * node_id + 1,
-                                  height - 1, rnode))
+                                  height - 1, rnode, sizeof(rnode)))
             return 0;
         adrsf->set_tree_height(adrs, height);
         adrsf->set_tree_index(adrs, node_id);
-        if (!ctx->hash_func->H(&ctx->hash_ctx, pk_seed, adrs, lnode, rnode, node))
+        if (!ctx->hash_func->H(&ctx->hash_ctx, pk_seed, adrs, lnode, rnode,
+                               node, node_len))
             return 0;
     }
     return 1;
@@ -106,6 +113,10 @@ static int slh_fors_node(SLH_DSA_CTX *ctx, const uint8_t *sk_seed,
  * @brief Generate an FORS signature
  * See FIPS 205 Section 8.3 Algorithm 16
  *
+ * A FORS signature has a size of (k * (1 + a) * n) bytes
+ * There are k trees, each of which have a private key value of size |n| followed
+ * by an authentication path of size |a| (where each path is size |n|)
+ *
  * @param ctx Contains SLH_DSA algorithm functions and constants.
  * @param md A message digest of size |(k * a + 7) / 8| bytes to sign
  * @param sk_seed A private key seed of size |n|
@@ -114,24 +125,23 @@ static int slh_fors_node(SLH_DSA_CTX *ctx, const uint8_t *sk_seed,
  *             tree address set to the XMSS tree that signs the FORS key,
  *             the type set to FORS_TREE, and the keypair address set to the
  *             index of the WOTS+ key that signs the FORS key.
- * @param sig_out The generated XMSS signature which consists of a WOTS+
- *                signature and authentication path
+ * @param sig_wpkt A WPACKET object to write the generated XMSS signature to
  * @param sig_len  The size of |sig| which is (2 * n + 3) * n + tree_height * n.
  * @returns 1 on success, or 0 on error.
  */
 int ossl_slh_fors_sign(SLH_DSA_CTX *ctx, const uint8_t *md,
                        const uint8_t *sk_seed, const uint8_t *pk_seed,
-                       SLH_ADRS adrs, uint8_t *sig, size_t sig_len)
+                       SLH_ADRS adrs, WPACKET *sig_wpkt)
 {
-    uint32_t i, j, s;
+    uint32_t tree_id, layer, s, tree_offset;
     uint32_t ids[SLH_MAX_K];
     const SLH_DSA_PARAMS *params = ctx->params;
     uint32_t n = params->n;
-    uint32_t k = params->k;
+    uint32_t k = params->k; /* number of trees */
     uint32_t a = params->a;
-    uint32_t t = (1 << a);
-    uint32_t t_times_i = 0;
-    uint8_t *psig = sig;
+    uint32_t two_power_a = (1 << a); /* this is t in FIPS 205 */
+    uint32_t tree_id_times_two_power_a = 0;
+    uint8_t out[SLH_MAX_N];
 
     /*
      * Split md into k a-bit values e.g with k = 14, a = 12
@@ -139,25 +149,42 @@ int ossl_slh_fors_sign(SLH_DSA_CTX *ctx, const uint8_t *md,
      */
     slh_base_2b(md, a, ids, k);
 
-    for (i = 0; i < k; ++i) {
-        uint32_t id = ids[i]; /* |id| = |a| bits */
+    for (tree_id = 0; tree_id < k; ++tree_id) {
+        /* Get the tree[i] leaf id */
+        uint32_t node_id = ids[tree_id]; /* |id| = |a| bits */
+
+        /*
+         * Give each of the k trees a unique range at each level.
+         * e.g. If we have 4096 leaf nodes (2^a = 2^12) for each tree
+         * tree i will use indexes from 4096 * i + (0..4095) for its bottom level.
+         * For the next level up from the bottom there would be 2048 nodes
+         * (so tree i uses indexes 2048 * i + (0...2047) for this level)
+         */
+        tree_offset = tree_id_times_two_power_a;
 
         if (!slh_fors_sk_gen(ctx, sk_seed, pk_seed, adrs,
-                             id + t_times_i, psig))
+                             node_id + tree_id_times_two_power_a, out, sizeof(out))
+                || !WPACKET_memcpy(sig_wpkt, out, n))
             return 0;
-        psig += n;
 
-        for (j = 0; j < a; ++j) {
-            s = id ^ 1;
-            if (!slh_fors_node(ctx, sk_seed, pk_seed, adrs, s + i * (1 << (a - j)),
-                               j, psig))
+        /*
+         * Traverse from the bottom of the tree (layer = 0)
+         * up to the root (layer = a - 1).
+         * TODO - This is a really inefficient way of doing this, since at
+         * layer a - 1 it calculates most of the hashes of the entire tree as
+         * well as all the leaf nodes. So it is calculating nodes multiple times.
+         */
+        for (layer = 0; layer < a; ++layer) {
+            s = node_id ^ 1; /* XOR gets the index of the other child in a binary tree */
+            if (!slh_fors_node(ctx, sk_seed, pk_seed, adrs,
+                               s + tree_offset, layer, out, sizeof(out)))
                 return 0;
-            id >>= 1;
-            psig += n;
+            node_id >>= 1;/* Get the parent node id */
+            tree_offset >>= 1; /* Each layer up has half as many nodes */
+            WPACKET_memcpy(sig_wpkt, out, n);
         }
-        t_times_i += t;
+        tree_id_times_two_power_a += two_power_a;
     }
-    assert((size_t)(psig - sig) == sig_len);
     return 1;
 }
 
@@ -165,8 +192,10 @@ int ossl_slh_fors_sign(SLH_DSA_CTX *ctx, const uint8_t *md,
  * @brief Compute a candidate FORS public key from a message and signature.
  * See FIPS 205 Section 8.4 Algorithm 17.
  *
+ * A FORS signature has a size of (k * (a + 1) * n) bytes
+ *
  * @param ctx Contains SLH_DSA algorithm functions and constants.
- * @param sig A FORS signature of size (k * (a + 1) * n) bytes
+ * @param fors_sig_rpkt A PACKET object to read a FORS signature from
  * @param md A message digest of size (k * a / 8) bytes
  * @param pk_seed A public key seed of size |n|
  * @param adrs The ADRS object must have a layer address of zero, and the
@@ -174,12 +203,14 @@ int ossl_slh_fors_sign(SLH_DSA_CTX *ctx, const uint8_t *md,
  *             the type set to FORS_TREE, and the keypair address set to the
  *             index of the WOTS+ key that signs the FORS key.
  * @param pk_out The returned candidate FORS public key of size |n|
+ * @param pk_out_len The maximum size of |pk_out|
  * @returns 1 on success, or 0 on error.
  */
-int ossl_slh_fors_pk_from_sig(SLH_DSA_CTX *ctx, const uint8_t *sig,
+int ossl_slh_fors_pk_from_sig(SLH_DSA_CTX *ctx, PACKET *fors_sig_rpkt,
                               const uint8_t *md, const uint8_t *pk_seed,
-                              SLH_ADRS adrs, uint8_t *pk_out)
+                              SLH_ADRS adrs, uint8_t *pk_out, size_t pk_out_len)
 {
+    int ret = 0;
     SLH_ADRS_DECLARE(pk_adrs);
     SLH_ADRS_FUNC_DECLARE(ctx, adrsf);
     SLH_ADRS_FN_DECLARE(adrsf, set_tree_index);
@@ -189,12 +220,19 @@ int ossl_slh_fors_pk_from_sig(SLH_DSA_CTX *ctx, const uint8_t *sig,
     SLH_HASH_FN_DECLARE(hashf, H);
     uint32_t i, j, aoff = 0;
     uint32_t ids[SLH_MAX_K];
-    uint8_t roots[SLH_MAX_ROOTS], *node = roots;
     const SLH_DSA_PARAMS *params = ctx->params;
     uint32_t a = params->a;
     uint32_t k = params->k;
     uint32_t n = params->n;
     uint32_t two_power_a = (1 << a);
+    const uint8_t *sk, *authj; /* Pointers to |sig| buffer inside fors_sig_rpkt */
+    uint8_t roots[SLH_MAX_ROOTS];
+    size_t roots_len = 0; /* The size of |roots| */
+    uint8_t *node0, *node1; /* Pointers into roots[] */
+    WPACKET root_pkt, *wroot_pkt = &root_pkt; /* Points to |roots| buffer */
+
+    if (!WPACKET_init_static_len(wroot_pkt, roots, sizeof(roots), 0))
+        return 0;
 
     /* Split md into k a-bit values e.g ids[0..k-1] = 12 bits each of md */
     slh_base_2b(md, a, ids, k);
@@ -206,36 +244,49 @@ int ossl_slh_fors_pk_from_sig(SLH_DSA_CTX *ctx, const uint8_t *sig,
 
         set_tree_height(adrs, 0);
         set_tree_index(adrs, node_id);
-        if (!F(hctx, pk_seed, adrs, sig, n, node))
-            return 0;
-        sig += n;
 
+        /* Regenerate the public key of the leaf */
+        if (!PACKET_get_bytes(fors_sig_rpkt, &sk, n)
+                || !WPACKET_allocate_bytes(wroot_pkt, n, &node0)
+                || !F(hctx, pk_seed, adrs, sk, n, node0, n))
+            goto err;
+
+        /* This omits the copying of the nodes that the FIPS 205 code does */
+        node1 = node0;
         for (j = 0; j < a; ++j) {
+            /* Get this layers other child public key */
+            if (!PACKET_get_bytes(fors_sig_rpkt, &authj, n))
+                goto err;
+            /* Hash the children together to get the parent nodes public key */
             set_tree_height(adrs, j + 1);
             if ((id & 1) == 0) {
                 node_id >>= 1;
                 set_tree_index(adrs, node_id);
-                if (!H(hctx, pk_seed, adrs, node, sig, node))
-                    return 0;
+                if (!H(hctx, pk_seed, adrs, node0, authj, node1, n))
+                    goto err;
             } else {
                 node_id = (node_id - 1) >> 1;
                 set_tree_index(adrs, node_id);
-                if (!H(hctx, pk_seed, adrs, sig, node, node))
-                    return 0;
+                if (!H(hctx, pk_seed, adrs, authj, node0, node1, n))
+                    goto err;
             }
             id >>= 1;
-            sig += n;
         }
         aoff += two_power_a;
-        node += n;
     }
-    assert((size_t)(node - roots) <= sizeof(roots));
+    if (!WPACKET_get_total_written(wroot_pkt, &roots_len))
+        goto err;
 
     /* The public key is the hash of all the roots of the k trees */
     adrsf->copy(pk_adrs, adrs);
     adrsf->set_type_and_clear(pk_adrs, SLH_ADRS_TYPE_FORS_ROOTS);
     adrsf->copy_keypair_address(pk_adrs, adrs);
-    return hashf->T(hctx, pk_seed, pk_adrs, roots, node - roots, pk_out);
+    ret = hashf->T(hctx, pk_seed, pk_adrs, roots, roots_len,
+                   pk_out, pk_out_len);
+ err:
+    if (!WPACKET_finish(wroot_pkt))
+        ret = 0;
+    return ret;
 }
 
 /**
index 9eafb42002dcafdf8888297b806646cc1b582c38..7f240276d5f5488d38cdd5d1c936fa0916f0ad8c 100644 (file)
@@ -160,21 +160,25 @@ static ossl_inline int xof_digest_4(EVP_MD_CTX *ctx,
 static int
 slh_hmsg_shake(SLH_HASH_CTX *hctx, const uint8_t *r, const uint8_t *pk_seed,
                const uint8_t *pk_root, const uint8_t *msg, size_t msg_len,
-               uint8_t *out)
+               uint8_t *out, size_t out_len)
 {
     size_t m = hctx->m;
     size_t n = hctx->n;
 
+    assert(m <= out_len);
+
     return xof_digest_4(hctx->md_ctx, r, n, pk_seed, n, pk_root, n,
                         msg, msg_len, out, m);
 }
 
 static int
 slh_prf_shake(SLH_HASH_CTX *hctx, const uint8_t *pk_seed, const uint8_t *sk_seed,
-              const SLH_ADRS adrs, uint8_t *out)
+              const SLH_ADRS adrs, uint8_t *out, size_t out_len)
 {
     size_t n = hctx->n;
 
+    assert(n <= out_len);
+
     return xof_digest_3(hctx->md_ctx, pk_seed, n, adrs, SLH_ADRS_SIZE,
                         sk_seed, n, out, n);
 }
@@ -182,40 +186,50 @@ slh_prf_shake(SLH_HASH_CTX *hctx, const uint8_t *pk_seed, const uint8_t *sk_seed
 static int
 slh_prf_msg_shake(SLH_HASH_CTX *hctx, const uint8_t *sk_prf,
                   const uint8_t *opt_rand, const uint8_t *msg, size_t msg_len,
-                  uint8_t *out)
+                  WPACKET *pkt)
 {
+    unsigned char out[SLH_MAX_N];
     size_t n = hctx->n;
 
+    assert(n <= sizeof(out));
+
     return xof_digest_3(hctx->md_ctx, sk_prf, n, opt_rand, n,
-                        msg, msg_len, out, n);
+                        msg, msg_len, out, n)
+        && WPACKET_memcpy(pkt, out, n);
 }
 
 static int
 slh_f_shake(SLH_HASH_CTX *hctx, const uint8_t *pk_seed, const SLH_ADRS adrs,
-            const uint8_t *m1, size_t m1_len, uint8_t *out)
+            const uint8_t *m1, size_t m1_len, uint8_t *out, size_t out_len)
 {
     size_t n = hctx->n;
 
+    assert(n <= out_len);
+
     return xof_digest_3(hctx->md_ctx, pk_seed, n, adrs, SLH_ADRS_SIZE,
                         m1, m1_len, out, n);
 }
 
 static int
 slh_h_shake(SLH_HASH_CTX *hctx, const uint8_t *pk_seed, const SLH_ADRS adrs,
-            const uint8_t *m1, const uint8_t *m2, uint8_t *out)
+            const uint8_t *m1, const uint8_t *m2, uint8_t *out, size_t out_len)
 {
     size_t n = hctx->n;
 
+    assert(n <= out_len);
+
     return xof_digest_4(hctx->md_ctx, pk_seed, n, adrs, SLH_ADRS_SIZE,
                         m1, n, m2, n, out, n);
 }
 
 static int
 slh_t_shake(SLH_HASH_CTX *hctx, const uint8_t *pk_seed, const SLH_ADRS adrs,
-            const uint8_t *ml, size_t ml_len, uint8_t *out)
+            const uint8_t *ml, size_t ml_len, uint8_t *out, size_t out_len)
 {
     size_t n = hctx->n;
 
+    assert(n <= out_len);
+
     return xof_digest_3(hctx->md_ctx, pk_seed, n, adrs, SLH_ADRS_SIZE,
                         ml, ml_len, out, n);
 }
@@ -239,13 +253,15 @@ digest_4(EVP_MD_CTX *ctx,
 static int
 slh_hmsg_sha2(SLH_HASH_CTX *hctx, const uint8_t *r, const uint8_t *pk_seed,
               const uint8_t *pk_root, const uint8_t *msg, size_t msg_len,
-              uint8_t *out)
+              uint8_t *out, size_t out_len)
 {
+    size_t m = hctx->m;
     size_t n = hctx->n;
     uint8_t seed[2 * SLH_MAX_N + MAX_DIGEST_SIZE];
     int sz = EVP_MD_get_size(hctx->md);
     size_t seed_len = (size_t)sz + 2 * n;
 
+    assert(m <= out_len);
     assert(sz > 0);
     assert(seed_len <= sizeof(seed));
 
@@ -253,13 +269,13 @@ slh_hmsg_sha2(SLH_HASH_CTX *hctx, const uint8_t *r, const uint8_t *pk_seed,
     memcpy(seed + n, pk_seed, n);
     return digest_4(hctx->md_big_ctx, r, n, pk_seed, n, pk_root, n, msg, msg_len,
                     seed + 2 * n)
-        && (PKCS1_MGF1(out, hctx->m, seed, seed_len, hctx->md) == 0);
+        && (PKCS1_MGF1(out, m, seed, seed_len, hctx->md) == 0);
 }
 
 static int
 slh_prf_msg_sha2(SLH_HASH_CTX *hctx,
                  const uint8_t *sk_prf, const uint8_t *opt_rand,
-                 const uint8_t *msg, size_t msg_len, uint8_t *out)
+                 const uint8_t *msg, size_t msg_len, WPACKET *pkt)
 {
     int ret;
     EVP_MAC_CTX *mctx = hctx->hmac_ctx;
@@ -288,19 +304,20 @@ slh_prf_msg_sha2(SLH_HASH_CTX *hctx,
     ret = EVP_MAC_init(mctx, sk_prf, n, p) == 1
         && EVP_MAC_update(mctx, opt_rand, n) == 1
         && EVP_MAC_update(mctx, msg, msg_len) == 1
-        && EVP_MAC_final(mctx, mac, NULL, sizeof(mac)) == 1;
-    memcpy(out, mac, n); /* Truncate output to n bytes */
+        && EVP_MAC_final(mctx, mac, NULL, sizeof(mac)) == 1
+        && WPACKET_memcpy(pkt, mac, n); /* Truncate output to n bytes */
     return ret;
 }
 
 static ossl_inline int
 do_hash(EVP_MD_CTX *ctx, size_t n, const uint8_t *pk_seed, const SLH_ADRS adrs,
-        const uint8_t *m, size_t m_len, size_t b, uint8_t *out)
+        const uint8_t *m, size_t m_len, size_t b, uint8_t *out, size_t out_len)
 {
     int ret;
     uint8_t zeros[128] = { 0 };
     uint8_t digest[MAX_DIGEST_SIZE];
 
+    assert(n <= out_len);
     assert(b - n < sizeof(zeros));
 
     ret = digest_4(ctx, pk_seed, n, zeros, b - n, adrs, SLH_ADRSC_SIZE,
@@ -312,25 +329,26 @@ do_hash(EVP_MD_CTX *ctx, size_t n, const uint8_t *pk_seed, const SLH_ADRS adrs,
 
 static int
 slh_prf_sha2(SLH_HASH_CTX *hctx, const uint8_t *pk_seed,
-             const uint8_t *sk_seed, const SLH_ADRS adrs, uint8_t *out)
+             const uint8_t *sk_seed, const SLH_ADRS adrs,
+             uint8_t *out, size_t out_len)
 {
     size_t n = hctx->n;
 
     return do_hash(hctx->md_ctx, n, pk_seed, adrs, sk_seed, n,
-                   SHA2_NUM_ZEROS_BOUND1, out);
+                   SHA2_NUM_ZEROS_BOUND1, out, out_len);
 }
 
 static int
 slh_f_sha2(SLH_HASH_CTX *hctx, const uint8_t *pk_seed, const SLH_ADRS adrs,
-           const uint8_t *m1, size_t m1_len, uint8_t *out)
+           const uint8_t *m1, size_t m1_len, uint8_t *out, size_t out_len)
 {
     return do_hash(hctx->md_ctx, hctx->n, pk_seed, adrs, m1, m1_len,
-                   SHA2_NUM_ZEROS_BOUND1, out);
+                   SHA2_NUM_ZEROS_BOUND1, out, out_len);
 }
 
 static int
 slh_h_sha2(SLH_HASH_CTX *hctx, const uint8_t *pk_seed, const SLH_ADRS adrs,
-           const uint8_t *m1, const uint8_t *m2, uint8_t *out)
+           const uint8_t *m1, const uint8_t *m2, uint8_t *out, size_t out_len)
 {
     uint8_t m[SLH_MAX_N * 2];
     size_t n = hctx->n;
@@ -338,15 +356,15 @@ slh_h_sha2(SLH_HASH_CTX *hctx, const uint8_t *pk_seed, const SLH_ADRS adrs,
     memcpy(m, m1, n);
     memcpy(m + n, m2, n);
     return do_hash(hctx->md_big_ctx, n, pk_seed, adrs, m, 2 * n,
-                   hctx->sha2_h_and_t_bound, out);
+                   hctx->sha2_h_and_t_bound, out, out_len);
 }
 
 static int
 slh_t_sha2(SLH_HASH_CTX *hctx, const uint8_t *pk_seed, const SLH_ADRS adrs,
-           const uint8_t *ml, size_t ml_len, uint8_t *out)
+           const uint8_t *ml, size_t ml_len, uint8_t *out, size_t out_len)
 {
     return do_hash(hctx->md_big_ctx, hctx->n, pk_seed, adrs, ml, ml_len,
-                   hctx->sha2_h_and_t_bound, out);
+                   hctx->sha2_h_and_t_bound, out, out_len);
 }
 
 const SLH_HASH_FUNC *ossl_slh_get_hash_fn(int is_shake)
index b6e3de77dba13c974dbdef65d1fd475fe8cb33d4..a48b7f2ff4d5ce58b30d88ce486e2bc10af1a1db 100644 (file)
@@ -13,6 +13,7 @@
 
 # include <openssl/e_os2.h>
 # include "slh_adrs.h"
+# include "internal/packet.h"
 
 # define SLH_HASH_FUNC_DECLARE(ctx, hashf, hashctx)   \
     const SLH_HASH_FUNC *hashf = ctx->hash_func;      \
@@ -40,22 +41,26 @@ typedef struct slh_hash_ctx_st {
  */
 typedef int (OSSL_SLH_HASHFUNC_H_MSG)(SLH_HASH_CTX *ctx, const uint8_t *r,
     const uint8_t *pk_seed, const uint8_t *pk_root,
-    const uint8_t *msg, size_t msg_len, uint8_t *out);
+    const uint8_t *msg, size_t msg_len, uint8_t *out, size_t out_len);
 
 typedef int (OSSL_SLH_HASHFUNC_PRF)(SLH_HASH_CTX *ctx, const uint8_t *pk_seed,
-    const uint8_t *sk_seed, const SLH_ADRS adrs, uint8_t *out);
+    const uint8_t *sk_seed, const SLH_ADRS adrs,
+    uint8_t *out, size_t out_len);
 
 typedef int (OSSL_SLH_HASHFUNC_PRF_MSG)(SLH_HASH_CTX *ctx, const uint8_t *sk_prf,
-    const uint8_t *opt_rand, const uint8_t *msg, size_t msg_len, uint8_t *out);
+    const uint8_t *opt_rand, const uint8_t *msg, size_t msg_len, WPACKET *pkt);
 
 typedef int (OSSL_SLH_HASHFUNC_F)(SLH_HASH_CTX *ctx, const uint8_t *pk_seed,
-    const SLH_ADRS adrs, const uint8_t *m1, size_t m1_len, uint8_t *out);
+    const SLH_ADRS adrs, const uint8_t *m1, size_t m1_len,
+    uint8_t *out, size_t out_len);
 
 typedef int (OSSL_SLH_HASHFUNC_H)(SLH_HASH_CTX *ctx, const uint8_t *pk_seed,
-    const SLH_ADRS adrs, const uint8_t *m1, const uint8_t *m2, uint8_t *out);
+    const SLH_ADRS adrs, const uint8_t *m1, const uint8_t *m2,
+    uint8_t *out, size_t out_len);
 
 typedef int (OSSL_SLH_HASHFUNC_T)(SLH_HASH_CTX *ctx, const uint8_t *pk_seed,
-    const SLH_ADRS adrs, const uint8_t *m1, size_t m1_len, uint8_t *out);
+    const SLH_ADRS adrs, const uint8_t *m1, size_t m1_len,
+    uint8_t *out, size_t out_len);
 
 typedef struct slh_hash_func_st {
     OSSL_SLH_HASHFUNC_H_MSG *H_MSG;
index fa16d45b77ff69b7b4e4de622e7e0434cbe3d0f4..e916fa8eff1b182320b621756103a7a57de6ef6d 100644 (file)
 #include <string.h>
 #include "slh_dsa_local.h"
 
-#define SLH_XMSS_SIG_LEN(n, hm) ((SLH_WOTS_LEN(n) + (hm)) * (n))
-
 /**
  * @brief Generate a Hypertree Signature
  * See FIPS 205 Section 7.1 Algorithm 12
  *
+ * This writes |d| XMSS signatures i.e. ((|h| + |d| * |len|) * |n|)
+ *
  * @param ctx Contains SLH_DSA algorithm functions and constants.
  * @param msg A message of size |n|.
  * @param sk_seed The private key seed of size |n|
  * @param pk_seed The public key seed of size |n|
  * @param tree_id Index of the XMSS tree that will sign the message
  * @param leaf_id Index of the WOTS+ key within the XMSS tree that will signed the message
- * @param sig The returned Hypertree Signature (which is |d| XMSS signatures)
- * @param sig_len The size of |sig| which is (|h| + |d| * |len|) * |n|)
+ * @param sig_wpkt A WPACKET object to write the Hypertree Signature to.
  * @returns 1 on success, or 0 on error.
  */
 int ossl_slh_ht_sign(SLH_DSA_CTX *ctx,
                      const uint8_t *msg, const uint8_t *sk_seed,
                      const uint8_t *pk_seed,
-                     uint64_t tree_id, uint32_t leaf_id,
-                     uint8_t *sig, size_t sig_len)
+                     uint64_t tree_id, uint32_t leaf_id, WPACKET *sig_wpkt)
 {
     SLH_ADRS_FUNC_DECLARE(ctx, adrsf);
     SLH_ADRS_DECLARE(adrs);
@@ -40,8 +38,8 @@ int ossl_slh_ht_sign(SLH_DSA_CTX *ctx,
     uint32_t n = ctx->params->n;
     uint32_t d = ctx->params->d;
     uint32_t hm = ctx->params->hm;
-    uint8_t *psig = sig;
-    size_t xmss_sig_len = SLH_XMSS_SIG_LEN(n, hm);
+    uint8_t *psig;
+    PACKET rpkt, *xmss_sig_rpkt = &rpkt;
 
     mask = (1 << hm) - 1; /* A mod 2^h = A & ((2^h - 1))) */
 
@@ -49,21 +47,23 @@ int ossl_slh_ht_sign(SLH_DSA_CTX *ctx,
     memcpy(root, msg, n);
 
     for (layer = 0; layer < d; ++layer) {
+        /* type = SLH_ADRS_TYPE_WOTS_HASH */
         adrsf->set_layer_address(adrs, layer);
         adrsf->set_tree_address(adrs, tree_id);
+        psig = WPACKET_get_curr(sig_wpkt);
         if (!ossl_slh_xmss_sign(ctx, root, sk_seed, leaf_id, pk_seed, adrs,
-                                psig, xmss_sig_len))
+                                sig_wpkt))
+            return 0;
+        if (!PACKET_buf_init(xmss_sig_rpkt, psig,  WPACKET_get_curr(sig_wpkt) - psig))
             return 0;
         if (layer < d - 1) {
-            if (!ossl_slh_xmss_pk_from_sig(ctx, leaf_id, psig, root,
-                                           pk_seed, adrs, root))
+            if (!ossl_slh_xmss_pk_from_sig(ctx, leaf_id, xmss_sig_rpkt, root,
+                                           pk_seed, adrs, root, sizeof(root)))
                 return 0;
         }
-        psig += xmss_sig_len;
         leaf_id = tree_id & mask;
         tree_id >>= hm;
     }
-    assert((size_t)(psig - sig) == sig_len);
     return 1;
 }
 
@@ -81,21 +81,19 @@ int ossl_slh_ht_sign(SLH_DSA_CTX *ctx,
  *
  * @returns 1 if the computed XMSS public key matches pk_root, or 0 otherwise.
  */
-int ossl_slh_ht_verify(SLH_DSA_CTX *ctx, const uint8_t *msg, const uint8_t *sig,
+int ossl_slh_ht_verify(SLH_DSA_CTX *ctx, const uint8_t *msg, PACKET *sig_pkt,
                        const uint8_t *pk_seed, uint64_t tree_id, uint32_t leaf_id,
                        const uint8_t *pk_root)
 {
     SLH_ADRS_FUNC_DECLARE(ctx, adrsf);
     SLH_ADRS_DECLARE(adrs);
     uint8_t node[SLH_MAX_N];
-    uint32_t layer, len, mask, d, n, tree_height;
     const SLH_DSA_PARAMS *params = ctx->params;
-
-    tree_height = params->hm;
-    n = params->n;
-    d = params->d;
-    len = SLH_XMSS_SIG_LEN(n, tree_height);
-    mask = (1 << tree_height) - 1;
+    uint32_t tree_height = params->hm;
+    uint32_t n = params->n;
+    uint32_t d = params->d;
+    uint32_t mask = (1 << tree_height) - 1;
+    uint32_t layer;
 
     adrsf->zero(adrs);
     memcpy(node, msg, n);
@@ -103,10 +101,9 @@ int ossl_slh_ht_verify(SLH_DSA_CTX *ctx, const uint8_t *msg, const uint8_t *sig,
     for (layer = 0; layer < d; ++layer) {
         adrsf->set_layer_address(adrs, layer);
         adrsf->set_tree_address(adrs, tree_id);
-        if (!ossl_slh_xmss_pk_from_sig(ctx, leaf_id, sig, node,
-                                       pk_seed, adrs, node))
+        if (!ossl_slh_xmss_pk_from_sig(ctx, leaf_id, sig_pkt, node,
+                                       pk_seed, adrs, node, sizeof(node)))
             return 0;
-        sig += len;
         leaf_id = tree_id & mask;
         tree_id >>= tree_height;
     }
index a0b2766745c1f43213404c6d30b8a28b88a39582..7f56dd457547f3f925f6736f9f1a54e6c57db39f 100644 (file)
@@ -82,32 +82,43 @@ static ossl_inline void compute_checksum_nibbles(const uint8_t *in, size_t in_le
  *
  * @param ctx Contains SLH_DSA algorithm functions and constants.
  * @param in An input string of |n| bytes
- * @param n The size of |in| and |pk_seed|_
  * @param start_index The chaining start index
  * @param steps The number of iterations starting from |start_index|
  *              Note |start_index| + |steps| < w
  *              (where w = 16 indicates the length of the hash chains)
+ * @param pk_seed A public key seed (which is added to the hash)
  * @param adrs An ADRS object which has a type of WOTS_HASH, and has a layer
  *             address, tree address, key pair address and chain address
- * @param pk_seed A public key seed (which is added to the hash)
+ * @params wpkt A WPACKET object to write the hash chain to (n bytes are written)
  * @returns 1 on success, or 0 on error.
  */
 static int slh_wots_chain(SLH_DSA_CTX *ctx, const uint8_t *in,
                           uint8_t start_index, uint8_t steps,
-                          const uint8_t *pk_seed, uint8_t *adrs, uint8_t *out)
+                          const uint8_t *pk_seed, uint8_t *adrs, WPACKET *wpkt)
 {
     SLH_HASH_FUNC_DECLARE(ctx, hashf, hctx);
     SLH_ADRS_FUNC_DECLARE(ctx, adrsf);
     SLH_HASH_FN_DECLARE(hashf, F);
     SLH_ADRS_FN_DECLARE(adrsf, set_hash_address);
-    size_t j, end_index = start_index + steps;
+    size_t j = start_index, end_index;
     size_t n = ctx->params->n;
+    uint8_t *tmp; /* Pointer into the |wpkt| buffer */
+    size_t tmp_len = n;
+
+    if (steps == 0)
+        return WPACKET_memcpy(wpkt, in, n);
+
+    if (!WPACKET_allocate_bytes(wpkt, tmp_len, &tmp))
+        return 0;
 
-    memcpy(out, in, n);
+    set_hash_address(adrs, j++);
+    if (!F(hctx, pk_seed, adrs, in, n, tmp, tmp_len))
+        return 0;
 
-    for (j = start_index; j < end_index; ++j) {
+    end_index = start_index + steps;
+    for (; j < end_index; ++j) {
         set_hash_address(adrs, j);
-        if (!F(hctx, pk_seed, adrs, out, n, out))
+        if (!F(hctx, pk_seed, adrs, tmp, n, tmp, tmp_len))
             return 0;
     }
     return 1;
@@ -123,11 +134,12 @@ static int slh_wots_chain(SLH_DSA_CTX *ctx, const uint8_t *in,
  * @param adrs An ADRS object containing the layer address, tree address and
  *             keypair address of the WOTS+ public key to generate.
  * @param pk_out The generated public key of size |n|
+ * @param pk_out_len The maximum size of |pk_out|
  * @returns 1 on success, or 0 on error.
  */
 int ossl_slh_wots_pk_gen(SLH_DSA_CTX *ctx,
                          const uint8_t *sk_seed, const uint8_t *pk_seed,
-                         SLH_ADRS adrs, uint8_t *pk_out)
+                         SLH_ADRS adrs, uint8_t *pk_out, size_t pk_out_len)
 {
     int ret = 0;
     SLH_HASH_FUNC_DECLARE(ctx, hashf, hctx);
@@ -136,33 +148,38 @@ int ossl_slh_wots_pk_gen(SLH_DSA_CTX *ctx,
     SLH_ADRS_FN_DECLARE(adrsf, set_chain_address);
     SLH_ADRS_DECLARE(sk_adrs);
     SLH_ADRS_DECLARE(wots_pk_adrs);
-    size_t i, len = 0;
     size_t n = ctx->params->n;
-    uint8_t tmp[SLH_WOTS_LEN_MAX * SLH_MAX_N], *ptmp = tmp;
-    uint8_t sk[32];
+    size_t i, len = SLH_WOTS_LEN(n); /*2 * n + 3 */
+    uint8_t sk[SLH_MAX_N];
+    uint8_t tmp[SLH_WOTS_LEN_MAX * SLH_MAX_N];
+    WPACKET pkt, *tmp_wpkt = &pkt; /* Points to the |tmp| buffer */
+    size_t tmp_len = 0;
 
+    if (!WPACKET_init_static_len(tmp_wpkt, tmp, sizeof(tmp), 0))
+        return 0;
     adrsf->copy(sk_adrs, adrs);
     adrsf->set_type_and_clear(sk_adrs, SLH_ADRS_TYPE_WOTS_PRF);
     adrsf->copy_keypair_address(sk_adrs, adrs);
 
-    len = SLH_WOTS_LEN(n); /* See Section 5 intro */
-    for (i = 0; i < len; ++i) {
+    for (i = 0; i < len; ++i) { /* len = 2n + 3 */
         set_chain_address(sk_adrs, i);
-        if (!PRF(hctx, pk_seed, sk_seed, sk_adrs, sk))
+        if (!PRF(hctx, pk_seed, sk_seed, sk_adrs, sk, sizeof(sk)))
             goto end;
 
         set_chain_address(adrs, i);
-        if (!slh_wots_chain(ctx, sk, 0, NIBBLE_MASK, pk_seed, adrs, ptmp))
+        if (!slh_wots_chain(ctx, sk, 0, NIBBLE_MASK, pk_seed, adrs, tmp_wpkt))
             goto end;
-        ptmp += n;
     }
 
-    len = ptmp - tmp; /* should be n * (2 * n + 3) */
+    if (!WPACKET_get_total_written(tmp_wpkt, &tmp_len)) /* should be n * (2 * n + 3) */
+        goto end;
     adrsf->copy(wots_pk_adrs, adrs);
     adrsf->set_type_and_clear(wots_pk_adrs, SLH_ADRS_TYPE_WOTS_PK);
     adrsf->copy_keypair_address(wots_pk_adrs, adrs);
-    ret = hashf->T(hctx, pk_seed, wots_pk_adrs, tmp, len, pk_out);
+    ret = hashf->T(hctx, pk_seed, wots_pk_adrs, tmp, tmp_len,
+                   pk_out, pk_out_len);
 end:
+    WPACKET_finish(tmp_wpkt);
     OPENSSL_cleanse(tmp, sizeof(tmp));
     OPENSSL_cleanse(sk, n);
     return ret;
@@ -172,6 +189,8 @@ end:
  * @brief WOTS+ Signature generation
  * See FIPS 205 Section 5.2 Algorithm 7
  *
+ * The returned signature size is len * |n| bytes (where len = 2 * |n| + 3).
+ *
  * @param ctx Contains SLH_DSA algorithm functions and constants.
  * @param msg An input message of size |n| bytes.
  *            The message is either an XMSS or FORS public key
@@ -179,14 +198,12 @@ end:
  * @param pk_seed The public key seed  of size |n| bytes
  * @param adrs An address containing the layer address, tree address and key
  *             pair address. The size is either 32 or 22 bytes.
- * @param sig The returned signature.
- * @param sig_len The size of |sig| which should be len * |n| bytes.
- *                (where len = 2 * |n| + 3)
+ * @param sig_wpkt A WPACKET object to write the signature to.
  * @returns 1 on success, or 0 on error.
  */
 int ossl_slh_wots_sign(SLH_DSA_CTX *ctx, const uint8_t *msg,
                        const uint8_t *sk_seed, const uint8_t *pk_seed,
-                       SLH_ADRS adrs, uint8_t *sig, size_t sig_len)
+                       SLH_ADRS adrs, WPACKET *sig_wpkt)
 {
     int ret = 0;
     SLH_HASH_FUNC_DECLARE(ctx, hashf, hctx);
@@ -198,7 +215,6 @@ int ossl_slh_wots_sign(SLH_DSA_CTX *ctx, const uint8_t *msg,
     uint8_t sk[SLH_MAX_N];
     size_t i, len1, len;
     size_t n = ctx->params->n;
-    uint8_t *psig = sig;
 
     len1 = SLH_WOTS_LEN1(n); /* 2 * n is for the message length in nibbles */
     len = len1 + SLH_WOTS_LEN2;  /* 2 * n + 3 (3 checksum nibbles) */
@@ -218,19 +234,16 @@ int ossl_slh_wots_sign(SLH_DSA_CTX *ctx, const uint8_t *msg,
     for (i = 0; i < len; ++i) {
         set_chain_address(sk_adrs, i);
         /* compute chain i secret */
-        if (!PRF(hctx, pk_seed, sk_seed, sk_adrs, sk))
+        if (!PRF(hctx, pk_seed, sk_seed, sk_adrs, sk, sizeof(sk)))
             goto err;
         set_chain_address(adrs, i);
         /* compute chain i signature */
         if (!slh_wots_chain(ctx, sk, 0, msg_and_csum_nibbles[i],
-                            pk_seed, adrs, psig))
+                            pk_seed, adrs, sig_wpkt))
             goto err;
-        psig += n;
     }
-    assert(sig_len == (size_t)(psig - sig));
     ret = 1;
 err:
-    OPENSSL_cleanse(sk, n);
     return ret;
 }
 
@@ -238,30 +251,40 @@ err:
  * @brief Compute a candidate WOTS+ public key from a message and signature
  * See FIPS 205 Section 5.3 Algorithm 8
  *
+ * The size of the signature is len * |n| bytes (where len = 2 * |n| + 3).
+ *
  * @param ctx Contains SLH_DSA algorithm functions and constants.
- * @param sig A WOTS+ signature of size len * |n| bytes. (where len = 2 * |n| + 3)
+ * @param sig_rpkt A PACKET object to read a WOTS+ signature from
  * @param msg A message of size |n| bytes.
  * @param pk_seed The public key seed of size |n|.
  * @param adrs An ADRS object containing the layer address, tree address and
  *             key pair address that of the WOTS+ key used to sign the message.
  * @param pk_out The returned public key candidate of size |n|
+ * @param pk_out_len The maximum size of |pk_out|
  * @returns 1 on success, or 0 on error.
  */
 int ossl_slh_wots_pk_from_sig(SLH_DSA_CTX *ctx,
-                              const uint8_t *sig, const uint8_t *msg,
+                              PACKET *sig_rpkt, const uint8_t *msg,
                               const uint8_t *pk_seed, uint8_t *adrs,
-                              uint8_t *pk_out)
+                              uint8_t *pk_out, size_t pk_out_len)
 {
+    int ret = 0;
     SLH_HASH_FUNC_DECLARE(ctx, hashf, hctx);
     SLH_ADRS_FUNC_DECLARE(ctx, adrsf);
     SLH_ADRS_FN_DECLARE(adrsf, set_chain_address);
     SLH_ADRS_DECLARE(wots_pk_adrs);
     uint8_t msg_and_csum_nibbles[SLH_WOTS_LEN_MAX];
-    uint8_t tmp[SLH_WOTS_LEN_MAX * SLH_MAX_N], *ptmp = tmp;
     size_t i, len1, len, n = ctx->params->n;
+    const uint8_t *sig_i;  /* Pointer into |pkt_sig| buffer */
+    uint8_t tmp[SLH_WOTS_LEN_MAX * SLH_MAX_N];
+    WPACKET pkt, *tmp_pkt = &pkt;
+    size_t tmp_len = 0;
 
     len1 = SLH_WOTS_LEN1(n);
-    len = len1 + SLH_WOTS_LEN2;
+    len = len1 + SLH_WOTS_LEN2; /* 2n + 3 */
+
+    if (!WPACKET_init_static_len(tmp_pkt, tmp, sizeof(tmp), 0))
+        return 0;
 
     slh_bytes_to_nibbles(msg, n, msg_and_csum_nibbles);
     compute_checksum_nibbles(msg_and_csum_nibbles, len1, msg_and_csum_nibbles + len1);
@@ -269,16 +292,22 @@ int ossl_slh_wots_pk_from_sig(SLH_DSA_CTX *ctx,
     /* Compute the end nodes for each of the chains */
     for (i = 0; i < len; ++i) {
         set_chain_address(adrs, i);
-        if (!slh_wots_chain(ctx, sig, msg_and_csum_nibbles[i],
-                            NIBBLE_MASK - msg_and_csum_nibbles[i],
-                            pk_seed, adrs, ptmp))
-            return 0;
-        sig += n;
-        ptmp += n;
+        if (!PACKET_get_bytes(sig_rpkt, &sig_i, n)
+                || !slh_wots_chain(ctx, sig_i, msg_and_csum_nibbles[i],
+                                   NIBBLE_MASK - msg_and_csum_nibbles[i],
+                                   pk_seed, adrs, tmp_pkt))
+            goto err;
     }
     /* compress the computed public key value */
     adrsf->copy(wots_pk_adrs, adrs);
     adrsf->set_type_and_clear(wots_pk_adrs, SLH_ADRS_TYPE_WOTS_PK);
     adrsf->copy_keypair_address(wots_pk_adrs, adrs);
-    return hashf->T(hctx, pk_seed, wots_pk_adrs, tmp, ptmp - tmp, pk_out);
+    if (!WPACKET_get_total_written(tmp_pkt, &tmp_len))
+        goto err;
+    ret = hashf->T(hctx, pk_seed, wots_pk_adrs, tmp, tmp_len,
+                    pk_out, pk_out_len);
+ err:
+    if (!WPACKET_finish(tmp_pkt))
+        ret = 0;
+    return ret;
 }
index 08f112f0d2f2dff8ea40c5707fa575552010b738..eb9125aec33efba9779f86afae7532d8435cb243 100644 (file)
  * the hash of each parent using 2 child nodes.
  *
  * @param ctx Contains SLH_DSA algorithm functions and constants.
- * @param sk_seed A private key seed of size |n|
+ * @param sk_seed A SLH-DSA private key seed of size |n|
  * @param nodeid The index of the target node being computed
  *               (which must be < 2^(hm - height)
  * @param h The height within the tree of the node being computed.
  *          (which must be <= hm) (hm is one of 3, 4, 8 or 9)
  *          At height=0 There are 2^hm leaf nodes,
  *          and the root node is at height = hm)
- * @param pk_seed A public key seed of size |n|
+ * @param pk_seed A SLH-DSA public key seed of size |n|
  * @param adrs An ADRS object containing the layer address and tree address set
  *             to the XMSS tree within which the XMSS tree is being computed.
  * @param pk_out The generated public key of size |n|
+ * @param pk_out_len The maximum size of |pk_out|
  * @returns 1 on success, or 0 on error.
  */
 int ossl_slh_xmss_node(SLH_DSA_CTX *ctx, const uint8_t *sk_seed,
                         uint32_t node_id, uint32_t h,
-                        const uint8_t *pk_seed, SLH_ADRS adrs, uint8_t *pk_out)
+                        const uint8_t *pk_seed, SLH_ADRS adrs,
+                        uint8_t *pk_out, size_t pk_out_len)
 {
     SLH_ADRS_FUNC_DECLARE(ctx, adrsf);
 
@@ -41,20 +43,22 @@ int ossl_slh_xmss_node(SLH_DSA_CTX *ctx, const uint8_t *sk_seed,
         /* For leaf nodes generate the public key */
         adrsf->set_type_and_clear(adrs, SLH_ADRS_TYPE_WOTS_HASH);
         adrsf->set_keypair_address(adrs, node_id);
-        if (!ossl_slh_wots_pk_gen(ctx, sk_seed, pk_seed, adrs, pk_out))
+        if (!ossl_slh_wots_pk_gen(ctx, sk_seed, pk_seed, adrs,
+                                  pk_out, pk_out_len))
             return 0;
     } else {
         uint8_t lnode[SLH_MAX_N], rnode[SLH_MAX_N];
 
         if (!ossl_slh_xmss_node(ctx, sk_seed, 2 * node_id, h - 1, pk_seed, adrs,
-                                lnode)
+                                lnode, sizeof(lnode))
                 || !ossl_slh_xmss_node(ctx, sk_seed, 2 * node_id + 1, h - 1,
-                                       pk_seed, adrs, rnode))
+                                       pk_seed, adrs, rnode, sizeof(rnode)))
             return 0;
         adrsf->set_type_and_clear(adrs, SLH_ADRS_TYPE_TREE);
         adrsf->set_tree_height(adrs, h);
         adrsf->set_tree_index(adrs, node_id);
-        if (!ctx->hash_func->H(&ctx->hash_ctx, pk_seed, adrs, lnode, rnode, pk_out))
+        if (!ctx->hash_func->H(&ctx->hash_ctx, pk_seed, adrs, lnode, rnode,
+                               pk_out, pk_out_len))
             return 0;
     }
     return 1;
@@ -64,6 +68,10 @@ int ossl_slh_xmss_node(SLH_DSA_CTX *ctx, const uint8_t *sk_seed,
  * @brief Generate an XMSS signature using a message and key.
  * See FIPS 205 Section 6.2 Algorithm 10
  *
+ * The generated signature consists of:
+ *  - A WOTS+ signature of size (2 * n + 3) * n
+ *  - An array of authentication paths of size (XMSS tree_height) * n.
+ *
  * @param ctx Contains SLH_DSA algorithm functions and constants.
  * @param msg A message of size |n| bytes to sign
  * @param sk_seed A private key seed of size |n|
@@ -71,56 +79,66 @@ int ossl_slh_xmss_node(SLH_DSA_CTX *ctx, const uint8_t *sk_seed,
  * @param pk_seed A public key seed f size |n|
  * @param adrs An ADRS object containing the layer address and tree address set
  *              to the XMSS key being used to sign the message.
- * @param sig The generated XMSS signature.
- * @param sig_len The size of |sig|. which consists of a WOTS+
- *                signature of size [2 * n + 3][n] followed by an authentication
- *                path of size [tree_height[n].
+ * @param sig_wpkt A WPACKET object to write the generated XMSS signature to.
  * @returns 1 on success, or 0 on error.
  */
 int ossl_slh_xmss_sign(SLH_DSA_CTX *ctx, const uint8_t *msg,
                        const uint8_t *sk_seed, uint32_t node_id,
-                       const uint8_t *pk_seed, SLH_ADRS adrs,
-                       uint8_t *sig, size_t sig_len)
+                       const uint8_t *pk_seed, SLH_ADRS adrs, WPACKET *sig_wpkt)
 {
     SLH_ADRS_FUNC_DECLARE(ctx, adrsf);
-    uint32_t h, id = node_id;
+    SLH_ADRS_DECLARE(tmp_adrs);
     size_t n = ctx->params->n;
-    uint32_t hm = ctx->params->hm;
-    size_t wots_sig_len = n * SLH_WOTS_LEN(n);
-    uint8_t *auth_path = sig + wots_sig_len;
+    uint32_t h, hm = ctx->params->hm;
+    uint32_t id = node_id;
+    uint8_t *auth_path; /* Pointer to a buffer offset inside |sig_wpkt| */
+    size_t auth_path_len = n;
 
+    /*
+     * This code reverses the order of the FIPS 205 code so that it does the
+     * sign first. This simplifies the WPACKET writing.
+     */
+    adrsf->copy(tmp_adrs, adrs);
+    adrsf->set_type_and_clear(adrs, SLH_ADRS_TYPE_WOTS_HASH);
+    adrsf->set_keypair_address(adrs, node_id);
+    if (!ossl_slh_wots_sign(ctx, msg, sk_seed, pk_seed, adrs, sig_wpkt))
+        return 0;
+
+    adrsf->copy(adrs, tmp_adrs);
     for (h = 0; h < hm; ++h) {
-        if (!ossl_slh_xmss_node(ctx, sk_seed, id ^ 1, h, pk_seed, adrs, auth_path))
+        if (!WPACKET_allocate_bytes(sig_wpkt, auth_path_len, &auth_path)
+                || !ossl_slh_xmss_node(ctx, sk_seed, id ^ 1, h, pk_seed, adrs,
+                                       auth_path, auth_path_len))
             return 0;
         id >>= 1;
-        auth_path += n;
     }
-    adrsf->set_type_and_clear(adrs, SLH_ADRS_TYPE_WOTS_HASH);
-    adrsf->set_keypair_address(adrs, node_id);
-    return ossl_slh_wots_sign(ctx, msg, sk_seed, pk_seed, adrs, sig, wots_sig_len);
+    return 1;
 }
 
 /**
  * @brief Compute a candidate XMSS public key from a message and XMSS signature
  * See FIPS 205 Section 6.3 Algorithm 11
  *
+ * * The signature consists of:
+ *  - A WOTS+ signature of size (2 * n + 3) * n
+ *  - An array of authentication paths of size (XMSS tree height) * n.
+ *
  * @param ctx Contains SLH_DSA algorithm functions and constants.
  * @param node_id Must be set to the |node_id| used in xmss_sign().
- * @param sig A XMSS signature which consists of a WOTS+ signature of
- *            [2 * n + 3][n] bytes followed by an authentication path of
- *            [hm][n] bytes (where hm is the height of the XMSS tree).
+ * @param sig_rpkt A Packet to read a XMSS signature from.
  * @param msg A message of size |n| bytes
  * @param sk_seed A private key seed of size |n|
  * @param pk_seed A public key seed of size |n|
  * @param adrs An ADRS object containing a layer address and tree address of an
  *             XMSS key used for signing the message.
  * @param pk_out The returned candidate XMSS public key of size |n|.
+ * @param pk_out_len The maximum size of |pk_out|.
  * @returns 1 on success, or 0 on error.
  */
 int ossl_slh_xmss_pk_from_sig(SLH_DSA_CTX *ctx, uint32_t node_id,
-                              const uint8_t *sig, const uint8_t *msg,
+                              PACKET *sig_rpkt, const uint8_t *msg,
                               const uint8_t *pk_seed, SLH_ADRS adrs,
-                              uint8_t *pk_out)
+                              uint8_t *pk_out, size_t pk_out_len)
 {
     SLH_HASH_FUNC_DECLARE(ctx, hashf, hctx);
     SLH_HASH_FN_DECLARE(hashf, H);
@@ -130,31 +148,32 @@ int ossl_slh_xmss_pk_from_sig(SLH_DSA_CTX *ctx, uint32_t node_id,
     uint32_t k;
     size_t n = ctx->params->n;
     uint32_t hm = ctx->params->hm;
-    size_t wots_sig_len = n * SLH_WOTS_LEN(n);
-    const uint8_t *auth_path = sig + wots_sig_len;
     uint8_t *node = pk_out;
+    const uint8_t *auth_path; /* Pointer to buffer offset in |pkt_sig| */
 
     adrsf->set_type_and_clear(adrs, SLH_ADRS_TYPE_WOTS_HASH);
     adrsf->set_keypair_address(adrs, node_id);
-    if (!ossl_slh_wots_pk_from_sig(ctx, sig, msg, pk_seed, adrs, node))
+    if (!ossl_slh_wots_pk_from_sig(ctx, sig_rpkt, msg, pk_seed, adrs,
+                                   node, pk_out_len))
         return 0;
 
     adrsf->set_type_and_clear(adrs, SLH_ADRS_TYPE_TREE);
 
     for (k = 0; k < hm; ++k) {
+        if (!PACKET_get_bytes(sig_rpkt, &auth_path, n))
+            return 0;
         set_tree_height(adrs, k + 1);
         if ((node_id & 1) == 0) { /* even */
             node_id >>= 1;
             set_tree_index(adrs, node_id);
-            if (!H(hctx, pk_seed, adrs, node, auth_path, node))
+            if (!H(hctx, pk_seed, adrs, node, auth_path, node, pk_out_len))
                 return 0;
         } else { /* odd */
             node_id = (node_id - 1) >> 1;
             set_tree_index(adrs, node_id);
-            if (!H(hctx, pk_seed, adrs, auth_path, node, node))
+            if (!H(hctx, pk_seed, adrs, auth_path, node, node, pk_out_len))
                 return 0;
         }
-        auth_path += n;
     }
     return 1;
 }
index b1c5d014f74c9a299029c87104d5e77c5a43cadb..4cca1f2b5726914492f83f0ab127734404a7ae11 100644 (file)
@@ -101,11 +101,12 @@ EVP_PKEY_verify_message_init(), EVP_PKEY_verify().
 Buffers
 -------
 
-Many functions need to pass around key elements and return signature buffers of
-various sizes which are often updated in loops in parts, all of these sizes
-are known quantities. Currently there is no attempt to use wpacket to pass
-around these sizes. asserts are currently done by the child functions to check
-that the expected size does not exceed the size passed in by the parent.
+There are many functions pass buffers of size |n| Where n is one of 16,24,32
+depending on the algorithm name. These are used for key elements and hashes, so
+PACKETS are not used for these.
+
+Where it makes sense to, WPACKET is used for output (such as signature generation)
+and PACKET for reading signature data.
 
 Constant Time Considerations
 ----------------------------
index 1926b637a5ae954468f6684544f67f844b0c18f4..ef026e2d35081ce25e9360597d1dfe6664beb1b3 100644 (file)
@@ -243,7 +243,6 @@ static int slh_dsa_sign_verify_test(int tst_id)
             || !TEST_int_eq(EVP_PKEY_sign(sctx, psig, &psig_len,
                                           td->msg, td->msg_len), 1))
         goto err;
-
     if (!TEST_int_eq(EVP_Q_digest(lib_ctx, "SHA256", NULL, psig, psig_len,
                                   digest, &digest_len), 1))
         goto err;