]> git.ipfire.org Git - thirdparty/openssl.git/commitdiff
SLH-DSA speed up hash calculations. master
authorslontis <shane.lontis@oracle.com>
Fri, 17 Oct 2025 05:32:06 +0000 (16:32 +1100)
committerslontis <shane.lontis@oracle.com>
Tue, 17 Feb 2026 01:20:04 +0000 (12:20 +1100)
SLH-DSA spends a significant amount of time performing large
numbers of hash calculations. Initially this was done using
EVP layer calls. The overhead is significant when there are thousands
of calls. To reduce this overhead the lower level sha functions for
KECCAK1600_CTX, SHA256_CTX and SHA512_CTX are accessed directly.

Profiling showed that a significant amount of time is spent in
"WOTS+ Public key generation" (FIPS 205 Section 5.1 Algorithm 6) so
this was inlined for shake and sha2 (See slh_wots_pk_gen_sha2()).

In FIPS 205 Section 11 there is a list of Hash functions.
Many of these functions use a pattern of
Trunc(n)(SHA256(PK.Seed || toByte(0, 64-n) || ....)
Because this operation is done many times, this prehashed
value is calculated once and stored into a low level SHA256_CTX or
KECCAK1600_CTX.
This can then be block copied to stack based KECCAK1600_CTX or
SHA256_CTX that we can then perform low level SHA functions on.
The md_len field is written to directly before the SHA final() to
control the length of the output (which avoids performing a memcpy).

Reviewed-by: Paul Dale <paul.dale@oracle.com>
Reviewed-by: Viktor Dukhovni <viktor@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/28941)

12 files changed:
crypto/sha/sha3.c
crypto/sha/sha512.c
crypto/slh_dsa/slh_dsa_hash_ctx.c
crypto/slh_dsa/slh_dsa_key.c
crypto/slh_dsa/slh_dsa_key.h
crypto/slh_dsa/slh_dsa_local.h
crypto/slh_dsa/slh_hash.c
crypto/slh_dsa/slh_hash.h
crypto/slh_dsa/slh_wots.c
include/crypto/md32_common.h
include/crypto/slh_dsa.h
include/internal/sha3.h

index 226f9113066f2e6e7dced648dc28aeff528f7a61..1b1104e05b3871b3be749f4b12c2f8e211b1b8d6 100644 (file)
@@ -8,11 +8,21 @@
  */
 
 #include <string.h>
+#include "internal/sha3.h"
+#include "internal/common.h"
+
+#if defined(__aarch64__) && defined(KECCAK1600_ASM)
+#include "crypto/arm_arch.h"
+#endif
+
 #if defined(__s390x__) && defined(OPENSSL_CPUID_OBJ)
 #include "crypto/s390x_arch.h"
+#if defined(KECCAK1600_ASM)
+#define S390_SHA3 1
+#define S390_SHA3_CAPABLE(name) \
+    ((OPENSSL_s390xcap_P.kimd[0] & S390X_CAPBIT(name)) && (OPENSSL_s390xcap_P.klmd[0] & S390X_CAPBIT(name)))
+#endif
 #endif
-#include "internal/sha3.h"
-#include "internal/common.h"
 
 void SHA3_squeeze(uint64_t A[5][5], unsigned char *out, size_t len, size_t r, int next);
 
@@ -158,7 +168,7 @@ int ossl_sha3_absorb(KECCAK1600_CTX *ctx, const unsigned char *inp, size_t len)
     /* Absorb the input - rem = leftover part of the input < blocksize) */
     rem = ctx->meth.absorb(ctx, inp, len);
     /* Copy the leftover bit of the input into the buffer */
-    if (ossl_likely(rem)) {
+    if (ossl_likely(rem > 0)) {
         memcpy(ctx->buf, inp + len - rem, rem);
         ctx->bufsz = rem;
     }
@@ -243,3 +253,155 @@ int ossl_sha3_squeeze(KECCAK1600_CTX *ctx, unsigned char *out, size_t outlen)
 
     return 1;
 }
+
+/*-
+ * Generic software version of the absorb() and final().
+ */
+static size_t generic_sha3_absorb(void *vctx, const void *inp, size_t len)
+{
+    KECCAK1600_CTX *ctx = vctx;
+
+    if (!(ctx->xof_state == XOF_STATE_INIT || ctx->xof_state == XOF_STATE_ABSORB))
+        return 0;
+    ctx->xof_state = XOF_STATE_ABSORB;
+    return SHA3_absorb(ctx->A, inp, len, ctx->block_size);
+}
+
+static int generic_sha3_final(void *vctx, unsigned char *out, size_t outlen)
+{
+    return ossl_sha3_final((KECCAK1600_CTX *)vctx, out, outlen);
+}
+
+static int generic_sha3_squeeze(void *vctx, unsigned char *out, size_t outlen)
+{
+    return ossl_sha3_squeeze((KECCAK1600_CTX *)vctx, out, outlen);
+}
+
+static PROV_SHA3_METHOD shake_generic_meth = {
+    generic_sha3_absorb,
+    generic_sha3_final,
+    generic_sha3_squeeze
+};
+
+#if defined(S390_SHA3)
+
+/*-
+ * The platform specific parts of the absorb() and final() for S390X.
+ */
+static size_t s390x_sha3_absorb(void *vctx, const void *inp, size_t len)
+{
+    KECCAK1600_CTX *ctx = vctx;
+    size_t rem = len % ctx->block_size;
+    unsigned int fc;
+
+    if (!(ctx->xof_state == XOF_STATE_INIT || ctx->xof_state == XOF_STATE_ABSORB))
+        return 0;
+    if (len - rem > 0) {
+        fc = ctx->pad;
+        fc |= ctx->xof_state == XOF_STATE_INIT ? S390X_KIMD_NIP : 0;
+        ctx->xof_state = XOF_STATE_ABSORB;
+        s390x_kimd(inp, len - rem, fc, ctx->A);
+    }
+    return rem;
+}
+
+static int s390x_shake_final(void *vctx, unsigned char *out, size_t outlen)
+{
+    KECCAK1600_CTX *ctx = vctx;
+    unsigned int fc;
+
+    if (!(ctx->xof_state == XOF_STATE_INIT || ctx->xof_state == XOF_STATE_ABSORB))
+        return 0;
+    fc = ctx->pad | S390X_KLMD_DUFOP;
+    fc |= ctx->xof_state == XOF_STATE_INIT ? S390X_KLMD_NIP : 0;
+    ctx->xof_state = XOF_STATE_FINAL;
+    s390x_klmd(ctx->buf, ctx->bufsz, out, outlen, fc, ctx->A);
+    return 1;
+}
+
+static int s390x_shake_squeeze(void *vctx, unsigned char *out, size_t outlen)
+{
+    KECCAK1600_CTX *ctx = vctx;
+    unsigned int fc;
+    size_t len;
+
+    if (ctx->xof_state == XOF_STATE_FINAL)
+        return 0;
+    /*
+     * On the first squeeze call, finish the absorb process (incl. padding).
+     */
+    if (ctx->xof_state != XOF_STATE_SQUEEZE) {
+        fc = ctx->pad;
+        fc |= ctx->xof_state == XOF_STATE_INIT ? S390X_KLMD_NIP : 0;
+        ctx->xof_state = XOF_STATE_SQUEEZE;
+        s390x_klmd(ctx->buf, ctx->bufsz, out, outlen, fc, ctx->A);
+        ctx->bufsz = outlen % ctx->block_size;
+        /* reuse ctx->bufsz to count bytes squeezed from current sponge */
+        return 1;
+    }
+    ctx->xof_state = XOF_STATE_SQUEEZE;
+    if (ctx->bufsz != 0) {
+        len = ctx->block_size - ctx->bufsz;
+        if (outlen < len)
+            len = outlen;
+        memcpy(out, (char *)ctx->A + ctx->bufsz, len);
+        out += len;
+        outlen -= len;
+        ctx->bufsz += len;
+        if (ctx->bufsz == ctx->block_size)
+            ctx->bufsz = 0;
+    }
+    if (outlen == 0)
+        return 1;
+    s390x_klmd(NULL, 0, out, outlen, ctx->pad | S390X_KLMD_PS, ctx->A);
+    ctx->bufsz = outlen % ctx->block_size;
+
+    return 1;
+}
+
+static PROV_SHA3_METHOD shake_s390x_meth = {
+    s390x_sha3_absorb,
+    s390x_shake_final,
+    s390x_shake_squeeze,
+};
+#elif defined(__aarch64__) && defined(KECCAK1600_ASM)
+
+size_t SHA3_absorb_cext(uint64_t A[5][5], const unsigned char *inp, size_t len,
+    size_t r);
+/*-
+ * Hardware-assisted ARMv8.2 SHA3 extension version of the absorb()
+ */
+static size_t armsha3_sha3_absorb(void *vctx, const void *inp, size_t len)
+{
+    KECCAK1600_CTX *ctx = vctx;
+
+    return SHA3_absorb_cext(ctx->A, inp, len, ctx->block_size);
+}
+
+static PROV_SHA3_METHOD shake_ARMSHA3_meth = {
+    armsha3_sha3_absorb,
+    generic_sha3_final,
+    generic_sha3_squeeze
+};
+#endif
+
+KECCAK1600_CTX *ossl_shake256_new(void)
+{
+    KECCAK1600_CTX *ctx = OPENSSL_zalloc(sizeof(*ctx));
+
+    if (ctx == NULL)
+        return NULL;
+    ossl_keccak_init(ctx, '\x1f', 256, 0);
+    ctx->md_size = SIZE_MAX;
+    ctx->meth = shake_generic_meth;
+#if defined(S390_SHA3)
+    if (S390_SHA3_CAPABLE(S390X_SHAKE_256)) {
+        ctx->pad = S390X_SHAKE_256;
+        ctx->meth = shake_s390x_meth;
+    }
+#elif defined(__aarch64__) && defined(KECCAK1600_ASM)
+    if (OPENSSL_armcap_P & ARMV8_HAVE_SHA3_AND_WORTH_USING)
+        ctx->meth = shake_ARMSHA3_meth;
+#endif
+    return ctx;
+}
index 09bf21771ae12ded79010e5721c25cf75aae5fe1..23a65177fef3322ca84a382a68a6ec658fa11c24 100644 (file)
@@ -15,6 +15,7 @@
 
 #include <stdio.h>
 #include <openssl/opensslconf.h>
+#include <openssl/byteorder.h>
 /*-
  * IMPLEMENTATION NOTES.
  *
@@ -154,7 +155,11 @@ void sha512_block_data_order_c(SHA512_CTX *ctx, const void *in, size_t num);
 #endif
     void sha512_block_data_order(SHA512_CTX *ctx, const void *in, size_t num);
 
-int SHA512_Final(unsigned char *md, SHA512_CTX *c)
+#define OUTPUT_RESULT(md, len)      \
+    for (n = 0; n < (len / 8); n++) \
+    md = OPENSSL_store_u64_be(md, (uint64_t)c->h[n])
+
+int SHA512_Final(unsigned char *out, SHA512_CTX *c)
 {
     unsigned char *p = (unsigned char *)c->u.p;
     size_t n = c->num;
@@ -172,58 +177,33 @@ int SHA512_Final(unsigned char *md, SHA512_CTX *c)
     c->u.d[SHA_LBLOCK - 2] = c->Nh;
     c->u.d[SHA_LBLOCK - 1] = c->Nl;
 #else
-    p[sizeof(c->u) - 1] = (unsigned char)(c->Nl);
-    p[sizeof(c->u) - 2] = (unsigned char)(c->Nl >> 8);
-    p[sizeof(c->u) - 3] = (unsigned char)(c->Nl >> 16);
-    p[sizeof(c->u) - 4] = (unsigned char)(c->Nl >> 24);
-    p[sizeof(c->u) - 5] = (unsigned char)(c->Nl >> 32);
-    p[sizeof(c->u) - 6] = (unsigned char)(c->Nl >> 40);
-    p[sizeof(c->u) - 7] = (unsigned char)(c->Nl >> 48);
-    p[sizeof(c->u) - 8] = (unsigned char)(c->Nl >> 56);
-    p[sizeof(c->u) - 9] = (unsigned char)(c->Nh);
-    p[sizeof(c->u) - 10] = (unsigned char)(c->Nh >> 8);
-    p[sizeof(c->u) - 11] = (unsigned char)(c->Nh >> 16);
-    p[sizeof(c->u) - 12] = (unsigned char)(c->Nh >> 24);
-    p[sizeof(c->u) - 13] = (unsigned char)(c->Nh >> 32);
-    p[sizeof(c->u) - 14] = (unsigned char)(c->Nh >> 40);
-    p[sizeof(c->u) - 15] = (unsigned char)(c->Nh >> 48);
-    p[sizeof(c->u) - 16] = (unsigned char)(c->Nh >> 56);
+    uint8_t *cu = p + sizeof(c->u) - 16;
+
+    cu = OPENSSL_store_u64_be(cu, (uint64_t)c->Nh);
+    cu = OPENSSL_store_u64_be(cu, (uint64_t)c->Nl);
 #endif
 
     sha512_block_data_order(c, p, 1);
 
-    if (md == 0)
+    if (out == NULL)
         return 0;
 
+    /* Let compiler decide if it's appropriate to unroll... */
     switch (c->md_len) {
     case SHA256_192_DIGEST_LENGTH:
-        for (n = 0; n < SHA256_192_DIGEST_LENGTH / 8; n++) {
-            SHA_LONG64 t = c->h[n];
-
-            *(md++) = (unsigned char)(t >> 56);
-            *(md++) = (unsigned char)(t >> 48);
-            *(md++) = (unsigned char)(t >> 40);
-            *(md++) = (unsigned char)(t >> 32);
-            *(md++) = (unsigned char)(t >> 24);
-            *(md++) = (unsigned char)(t >> 16);
-            *(md++) = (unsigned char)(t >> 8);
-            *(md++) = (unsigned char)(t);
-        }
+        OUTPUT_RESULT(out, SHA256_192_DIGEST_LENGTH);
         break;
-    /* Let compiler decide if it's appropriate to unroll... */
-    case SHA224_DIGEST_LENGTH:
-        for (n = 0; n < SHA224_DIGEST_LENGTH / 8; n++) {
-            SHA_LONG64 t = c->h[n];
-
-            *(md++) = (unsigned char)(t >> 56);
-            *(md++) = (unsigned char)(t >> 48);
-            *(md++) = (unsigned char)(t >> 40);
-            *(md++) = (unsigned char)(t >> 32);
-            *(md++) = (unsigned char)(t >> 24);
-            *(md++) = (unsigned char)(t >> 16);
-            *(md++) = (unsigned char)(t >> 8);
-            *(md++) = (unsigned char)(t);
-        }
+    case SHA256_DIGEST_LENGTH:
+        OUTPUT_RESULT(out, SHA256_DIGEST_LENGTH);
+        break;
+    case SHA384_DIGEST_LENGTH:
+        OUTPUT_RESULT(out, SHA384_DIGEST_LENGTH);
+        break;
+    case SHA512_DIGEST_LENGTH:
+        OUTPUT_RESULT(out, SHA512_DIGEST_LENGTH);
+        break;
+    case SHA224_DIGEST_LENGTH: {
+        OUTPUT_RESULT(out, SHA224_DIGEST_LENGTH);
         /*
          * For 224 bits, there are four bytes left over that have to be
          * processed separately.
@@ -231,54 +211,13 @@ int SHA512_Final(unsigned char *md, SHA512_CTX *c)
         {
             SHA_LONG64 t = c->h[SHA224_DIGEST_LENGTH / 8];
 
-            *(md++) = (unsigned char)(t >> 56);
-            *(md++) = (unsigned char)(t >> 48);
-            *(md++) = (unsigned char)(t >> 40);
-            *(md++) = (unsigned char)(t >> 32);
-        }
-        break;
-    case SHA256_DIGEST_LENGTH:
-        for (n = 0; n < SHA256_DIGEST_LENGTH / 8; n++) {
-            SHA_LONG64 t = c->h[n];
-
-            *(md++) = (unsigned char)(t >> 56);
-            *(md++) = (unsigned char)(t >> 48);
-            *(md++) = (unsigned char)(t >> 40);
-            *(md++) = (unsigned char)(t >> 32);
-            *(md++) = (unsigned char)(t >> 24);
-            *(md++) = (unsigned char)(t >> 16);
-            *(md++) = (unsigned char)(t >> 8);
-            *(md++) = (unsigned char)(t);
-        }
-        break;
-    case SHA384_DIGEST_LENGTH:
-        for (n = 0; n < SHA384_DIGEST_LENGTH / 8; n++) {
-            SHA_LONG64 t = c->h[n];
-
-            *(md++) = (unsigned char)(t >> 56);
-            *(md++) = (unsigned char)(t >> 48);
-            *(md++) = (unsigned char)(t >> 40);
-            *(md++) = (unsigned char)(t >> 32);
-            *(md++) = (unsigned char)(t >> 24);
-            *(md++) = (unsigned char)(t >> 16);
-            *(md++) = (unsigned char)(t >> 8);
-            *(md++) = (unsigned char)(t);
-        }
-        break;
-    case SHA512_DIGEST_LENGTH:
-        for (n = 0; n < SHA512_DIGEST_LENGTH / 8; n++) {
-            SHA_LONG64 t = c->h[n];
-
-            *(md++) = (unsigned char)(t >> 56);
-            *(md++) = (unsigned char)(t >> 48);
-            *(md++) = (unsigned char)(t >> 40);
-            *(md++) = (unsigned char)(t >> 32);
-            *(md++) = (unsigned char)(t >> 24);
-            *(md++) = (unsigned char)(t >> 16);
-            *(md++) = (unsigned char)(t >> 8);
-            *(md++) = (unsigned char)(t);
+            *(out++) = (unsigned char)(t >> 56);
+            *(out++) = (unsigned char)(t >> 48);
+            *(out++) = (unsigned char)(t >> 40);
+            *(out++) = (unsigned char)(t >> 32);
         }
         break;
+    }
     /* ... as well as make sure md_len is not abused. */
     default:
         return 0;
index 9dca01acf5ffc4900f8df07da0007c427f9bbed3..b960e24f1de51a88d55dd94da8050d5d18458d42 100644 (file)
@@ -11,6 +11,8 @@
 #include "slh_dsa_local.h"
 #include "slh_dsa_key.h"
 #include <openssl/evp.h>
+#include <openssl/sha.h>
+#include "crypto/evp.h"
 
 /**
  * @brief Create a SLH_DSA_HASH_CTX that contains parameters, functions, and
@@ -31,23 +33,10 @@ SLH_DSA_HASH_CTX *ossl_slh_dsa_hash_ctx_new(const SLH_DSA_KEY *key)
         return NULL;
 
     ret->key = key;
-    ret->md_ctx = EVP_MD_CTX_new();
-    if (ret->md_ctx == NULL)
+    if (key->pub != NULL
+        && !ossl_slh_dsa_hash_ctx_prehash_pk_seed(ret, SLH_DSA_PK_SEED(key), key->params->n))
         goto err;
-    if (EVP_DigestInit_ex2(ret->md_ctx, key->md, NULL) != 1)
-        goto err;
-    if (key->md_big != NULL) {
-        /* Gets here for SHA2 algorithms */
-        if (key->md_big == key->md) {
-            ret->md_big_ctx = ret->md_ctx;
-        } else {
-            /* Only gets here for SHA2 */
-            ret->md_big_ctx = EVP_MD_CTX_new();
-            if (ret->md_big_ctx == NULL)
-                goto err;
-            if (EVP_DigestInit_ex2(ret->md_big_ctx, key->md_big, NULL) != 1)
-                goto err;
-        }
+    if (!key->params->is_shake) {
         if (key->hmac != NULL) {
             ret->hmac_ctx = EVP_MAC_CTX_new(key->hmac);
             if (ret->hmac_ctx == NULL)
@@ -76,17 +65,8 @@ SLH_DSA_HASH_CTX *ossl_slh_dsa_hash_ctx_dup(const SLH_DSA_HASH_CTX *src)
     /* Note that the key is not ref counted, since it does not own the key */
     ret->key = src->key;
 
-    if (src->md_ctx != NULL
-        && (ret->md_ctx = EVP_MD_CTX_dup(src->md_ctx)) == NULL)
+    if (!src->key->hash_func->prehash_dup(ret, src))
         goto err;
-    if (src->md_big_ctx != NULL) {
-        if (src->md_big_ctx != src->md_ctx) {
-            if ((ret->md_big_ctx = EVP_MD_CTX_dup(src->md_big_ctx)) == NULL)
-                goto err;
-        } else {
-            ret->md_big_ctx = ret->md_ctx;
-        }
-    }
     if (src->hmac_ctx != NULL
         && (ret->hmac_ctx = EVP_MAC_CTX_dup(src->hmac_ctx)) == NULL)
         goto err;
@@ -96,6 +76,19 @@ err:
     return NULL;
 }
 
+/**
+ * @brief Cache the pk seed.
+ * SLH_DSA performs a large number of hash operations that consist of either
+ *  SHAKE256(PK.seed || .. ) OR
+ *  SHA256(PK.seed || toByte(0, 64 - n) || ...)
+ * So cache this value and reuse it as the starting point for many hash functions.
+ */
+int ossl_slh_dsa_hash_ctx_prehash_pk_seed(SLH_DSA_HASH_CTX *ctx,
+    const uint8_t *pkseed, size_t n)
+{
+    return ctx->key->hash_func->prehash_pk_seed(ctx, pkseed, n);
+}
+
 /**
  * @brief Destroy a SLH_DSA_HASH_CTX
  *
@@ -105,9 +98,8 @@ void ossl_slh_dsa_hash_ctx_free(SLH_DSA_HASH_CTX *ctx)
 {
     if (ctx == NULL)
         return;
-    EVP_MD_CTX_free(ctx->md_ctx);
-    if (ctx->md_big_ctx != ctx->md_ctx)
-        EVP_MD_CTX_free(ctx->md_big_ctx);
+    OPENSSL_free(ctx->shactx);
+    OPENSSL_free(ctx->shactx_pkseed);
     EVP_MAC_CTX_free(ctx->hmac_ctx);
     OPENSSL_free(ctx);
 }
index 9df2d75a861e7e95f9be173fb07a0e75da08f105..efc83259eda1addb82a48c898dc6c673bf51276e 100644 (file)
@@ -23,12 +23,11 @@ static int slh_dsa_compute_pk_root(SLH_DSA_HASH_CTX *ctx, SLH_DSA_KEY *out, int
 static void slh_dsa_key_hash_cleanup(SLH_DSA_KEY *key)
 {
     OPENSSL_free(key->propq);
-    if (key->md_big != key->md)
-        EVP_MD_free(key->md_big);
-    key->md_big = NULL;
+    EVP_MD_free(key->md_sha512);
     EVP_MD_free(key->md);
     EVP_MAC_free(key->hmac);
     key->md = NULL;
+    key->md_sha512 = NULL;
 }
 
 static int slh_dsa_key_hash_init(SLH_DSA_KEY *key)
@@ -45,13 +44,10 @@ static int slh_dsa_key_hash_init(SLH_DSA_KEY *key)
      * SHAKE algorithm(s) use SHAKE for all functions.
      */
     if (is_shake == 0) {
-        if (security_category == 1) {
-            /* For category 1 SHA2-256 is used for all hash operations */
-            key->md_big = key->md;
-        } else {
+        if (security_category != 1) {
             /* Security categories 3 & 5 also need SHA-512 */
-            key->md_big = EVP_MD_fetch(key->libctx, "SHA2-512", key->propq);
-            if (key->md_big == NULL)
+            key->md_sha512 = EVP_MD_fetch(key->libctx, "SHA2-512", key->propq);
+            if (key->md_sha512 == NULL)
                 goto err;
         }
         key->hmac = EVP_MAC_fetch(key->libctx, "HMAC", key->propq);
@@ -59,7 +55,7 @@ static int slh_dsa_key_hash_init(SLH_DSA_KEY *key)
             goto err;
     }
     key->adrs_func = ossl_slh_get_adrs_fn(is_shake == 0);
-    key->hash_func = ossl_slh_get_hash_fn(is_shake);
+    key->hash_func = ossl_slh_get_hash_fn(is_shake, security_category);
     return 1;
 err:
     slh_dsa_key_hash_cleanup(key);
@@ -68,8 +64,8 @@ err:
 
 static void slh_dsa_key_hash_dup(SLH_DSA_KEY *dst, const SLH_DSA_KEY *src)
 {
-    if (src->md_big != NULL && src->md_big != src->md)
-        EVP_MD_up_ref(src->md_big);
+    if (src->md_sha512 != NULL)
+        EVP_MD_up_ref(src->md_sha512);
     if (src->md != NULL)
         EVP_MD_up_ref(src->md);
     if (src->hmac != NULL)
@@ -382,7 +378,9 @@ int ossl_slh_dsa_generate_key(SLH_DSA_HASH_CTX *ctx, SLH_DSA_KEY *out,
             || RAND_bytes_ex(lib_ctx, pub, pk_seed_len, 0) <= 0)
             goto err;
     }
-    if (!slh_dsa_compute_pk_root(ctx, out, 0))
+
+    if (!ossl_slh_dsa_hash_ctx_prehash_pk_seed(ctx, pub, pk_seed_len)
+        || !slh_dsa_compute_pk_root(ctx, out, 0))
         goto err;
     out->pub = pub;
     out->has_priv = 1;
index 37b7aa1b16ad6ffdec50e58981b1e999241d76cd..b3bbd740d38c3357064c6a68cd33ca03896c17a3 100644 (file)
@@ -44,7 +44,7 @@ struct slh_dsa_key_st {
     const SLH_HASH_FUNC *hash_func;
     /* See FIPS 205 Section 11.1 */
 
-    EVP_MD *md; /* Used for SHAKE and SHA-256 */
-    EVP_MD *md_big; /* Used for SHA-256 or SHA-512 */
+    EVP_MD *md; /* Used for general SHAKE and SHA-256 hashes */
+    EVP_MD *md_sha512; /* Used for SHA-512 hashes */
     EVP_MAC *hmac;
 };
index 57dfc1eb13011b5da33ebcc123b89253417c527a..b717e51a87999d836e16e9db842504ae0d7fe35c 100644 (file)
@@ -47,8 +47,8 @@
  */
 struct slh_dsa_hash_ctx_st {
     const SLH_DSA_KEY *key; /* This key is not owned by this object */
-    EVP_MD_CTX *md_ctx; /* Either SHAKE OR SHA-256 */
-    EVP_MD_CTX *md_big_ctx; /* Either SHA-512 or points to |md_ctx| for SHA-256*/
+    void *shactx; /* A low level SHAKE object */
+    void *shactx_pkseed; /* A low level SHAKE or SHA256 object with PK.seed hashed in it */
     EVP_MAC_CTX *hmac_ctx; /* required by SHA algorithms for PRFmsg() */
     int hmac_digest_used; /* Used for lazy init of hmac_ctx digest */
 };
@@ -63,11 +63,11 @@ __owur int ossl_slh_wots_pk_from_sig(SLH_DSA_HASH_CTX *ctx,
     PACKET *sig_rpkt, const uint8_t *msg,
     const uint8_t *pk_seed, uint8_t *adrs,
     uint8_t *pk_out, size_t pk_out_len);
-
 __owur int ossl_slh_xmss_node(SLH_DSA_HASH_CTX *ctx, const uint8_t *sk_seed,
     uint32_t node_id, uint32_t height,
     const uint8_t *pk_seed, uint8_t *adrs,
     uint8_t *pk_out, size_t pk_out_len);
+
 __owur int ossl_slh_xmss_sign(SLH_DSA_HASH_CTX *ctx, const uint8_t *msg,
     const uint8_t *sk_seed, uint32_t node_id,
     const uint8_t *pk_seed, uint8_t *adrs,
index 951019819c16860db6530023486189ccf8bd3750..1deed75a3228a9191aefbaa50b879a71e80389ee 100644 (file)
 #include "slh_dsa_local.h"
 #include "slh_dsa_key.h"
 
+#include "openssl/sha.h"
+#include "internal/sha3.h"
+#include "crypto/evp.h"
+#include "crypto/sha.h"
+
 #define MAX_DIGEST_SIZE 64 /* SHA-512 is used for security category 3 & 5 */
+#define NIBBLE_MASK 15
 
-static OSSL_SLH_HASHFUNC_H_MSG slh_hmsg_sha2;
-static OSSL_SLH_HASHFUNC_PRF slh_prf_sha2;
-static OSSL_SLH_HASHFUNC_PRF_MSG slh_prf_msg_sha2;
-static OSSL_SLH_HASHFUNC_F slh_f_sha2;
-static OSSL_SLH_HASHFUNC_H slh_h_sha2;
-static OSSL_SLH_HASHFUNC_T slh_t_sha2;
+/* Most hash functions in SLH-DSA truncate the output */
+#define sha256_final(ctx, out, outlen)    \
+    (ctx)->md_len = (unsigned int)outlen; \
+    SHA256_Final(out, ctx)
 
-static OSSL_SLH_HASHFUNC_H_MSG slh_hmsg_shake;
+#define sha512_final(ctx, out, outlen)    \
+    (ctx)->md_len = (unsigned int)outlen; \
+    SHA512_Final(out, ctx)
+
+static OSSL_SLH_HASHFUNC_PRF slh_prf_sha256;
 static OSSL_SLH_HASHFUNC_PRF slh_prf_shake;
-static OSSL_SLH_HASHFUNC_PRF_MSG slh_prf_msg_shake;
+
+static OSSL_SLH_HASHFUNC_F slh_f_sha256;
 static OSSL_SLH_HASHFUNC_F slh_f_shake;
-static OSSL_SLH_HASHFUNC_H slh_h_shake;
-static OSSL_SLH_HASHFUNC_T slh_t_shake;
 
-static ossl_inline int xof_digest_3(EVP_MD_CTX *ctx,
-    const uint8_t *in1, size_t in1_len,
-    const uint8_t *in2, size_t in2_len,
-    const uint8_t *in3, size_t in3_len,
-    uint8_t *out, size_t out_len)
-{
-    return (EVP_DigestInit_ex2(ctx, NULL, NULL) == 1
-        && EVP_DigestUpdate(ctx, in1, in1_len) == 1
-        && EVP_DigestUpdate(ctx, in2, in2_len) == 1
-        && EVP_DigestUpdate(ctx, in3, in3_len) == 1
-        && EVP_DigestFinalXOF(ctx, out, out_len) == 1);
-}
+static OSSL_SLH_HASHFUNC_PRF_MSG slh_prf_msg_sha2;
+static OSSL_SLH_HASHFUNC_PRF_MSG slh_prf_msg_shake;
 
-static ossl_inline int xof_digest_4(EVP_MD_CTX *ctx,
-    const uint8_t *in1, size_t in1_len,
-    const uint8_t *in2, size_t in2_len,
-    const uint8_t *in3, size_t in3_len,
-    const uint8_t *in4, size_t in4_len,
-    uint8_t *out, size_t out_len)
-{
-    return (EVP_DigestInit_ex2(ctx, NULL, NULL) == 1
-        && EVP_DigestUpdate(ctx, in1, in1_len) == 1
-        && EVP_DigestUpdate(ctx, in2, in2_len) == 1
-        && EVP_DigestUpdate(ctx, in3, in3_len) == 1
-        && EVP_DigestUpdate(ctx, in4, in4_len) == 1
-        && EVP_DigestFinalXOF(ctx, out, out_len) == 1);
-}
+static OSSL_SLH_HASHFUNC_H_MSG slh_hmsg_sha256;
+static OSSL_SLH_HASHFUNC_H_MSG slh_hmsg_sha512;
+static OSSL_SLH_HASHFUNC_H_MSG slh_hmsg_shake;
+
+static OSSL_SLH_HASHFUNC_H slh_h_sha256;
+static OSSL_SLH_HASHFUNC_H slh_h_sha512;
+static OSSL_SLH_HASHFUNC_H slh_h_shake;
+static OSSL_SLH_HASHFUNC_T slh_t_sha256;
+static OSSL_SLH_HASHFUNC_T slh_t_sha512;
+static OSSL_SLH_HASHFUNC_wots_pk_gen slh_wots_pk_gen_sha2;
+static OSSL_SLH_HASHFUNC_wots_pk_gen slh_wots_pk_gen_shake;
+
+static const uint8_t zeros[128] = { 0 };
 
 /* See FIPS 205 Section 11.1 */
 static int
-slh_hmsg_shake(SLH_DSA_HASH_CTX *ctx, const uint8_t *r,
+slh_hmsg_shake(SLH_DSA_HASH_CTX *hctx, const uint8_t *r,
     const uint8_t *pk_seed, const uint8_t *pk_root,
     const uint8_t *msg, size_t msg_len,
     uint8_t *out, size_t out_len)
 {
-    const SLH_DSA_PARAMS *params = ctx->key->params;
+    KECCAK1600_CTX *sctx = (KECCAK1600_CTX *)(hctx->shactx);
+    const SLH_DSA_PARAMS *params = hctx->key->params;
     size_t m = params->m;
     size_t n = params->n;
 
-    return xof_digest_4(ctx->md_ctx, r, n, pk_seed, n, pk_root, n,
-        msg, msg_len, out, m);
+    ossl_sha3_reset(sctx);
+    ossl_sha3_absorb(sctx, r, n);
+    ossl_sha3_absorb(sctx, pk_seed, n);
+    ossl_sha3_absorb(sctx, pk_root, n);
+    ossl_sha3_absorb(sctx, msg, msg_len);
+    ossl_sha3_final(sctx, out, m);
+    return 1;
 }
 
 static int
-slh_prf_shake(SLH_DSA_HASH_CTX *ctx,
-    const uint8_t *pk_seed, const uint8_t *sk_seed,
-    const uint8_t *adrs, uint8_t *out, size_t out_len)
+slh_prf_msg_shake(SLH_DSA_HASH_CTX *hctx, const uint8_t *sk_prf,
+    const uint8_t *opt_rand, const uint8_t *msg, size_t msg_len,
+    WPACKET *pkt)
 {
-    const SLH_DSA_PARAMS *params = ctx->key->params;
+    unsigned char out[SLH_MAX_N];
+    const SLH_DSA_PARAMS *params = hctx->key->params;
     size_t n = params->n;
-
-    return xof_digest_3(ctx->md_ctx, pk_seed, n, adrs, SLH_ADRS_SIZE,
-        sk_seed, n, out, n);
+    KECCAK1600_CTX *sctx = (KECCAK1600_CTX *)(hctx->shactx);
+
+    ossl_sha3_reset(sctx);
+    ossl_sha3_absorb(sctx, sk_prf, n);
+    ossl_sha3_absorb(sctx, opt_rand, n);
+    ossl_sha3_absorb(sctx, msg, msg_len);
+    ossl_sha3_final(sctx, out, n);
+    return WPACKET_memcpy(pkt, out, n);
 }
 
 static int
-slh_prf_msg_shake(SLH_DSA_HASH_CTX *ctx, const uint8_t *sk_prf,
-    const uint8_t *opt_rand, const uint8_t *msg, size_t msg_len,
-    WPACKET *pkt)
+slh_f_shake(SLH_DSA_HASH_CTX *hctx, const uint8_t *pk_seed, const uint8_t *adrs,
+    const uint8_t *m1, size_t m1_len, uint8_t *out, size_t out_len)
 {
-    unsigned char out[SLH_MAX_N];
-    const SLH_DSA_PARAMS *params = ctx->key->params;
+    const SLH_DSA_PARAMS *params = hctx->key->params;
     size_t n = params->n;
+    KECCAK1600_CTX sctx = *((KECCAK1600_CTX *)(hctx->shactx_pkseed));
 
-    return xof_digest_3(ctx->md_ctx, sk_prf, n, opt_rand, n, msg, msg_len, out, n)
-        && WPACKET_memcpy(pkt, out, n);
+    ossl_sha3_absorb(&sctx, adrs, SLH_ADRS_SIZE);
+    ossl_sha3_absorb(&sctx, m1, m1_len);
+    ossl_sha3_final(&sctx, out, n);
+    return 1;
 }
 
 static int
-slh_f_shake(SLH_DSA_HASH_CTX *ctx, const uint8_t *pk_seed, const uint8_t *adrs,
-    const uint8_t *m1, size_t m1_len, uint8_t *out, size_t out_len)
+slh_prf_shake(SLH_DSA_HASH_CTX *hctx,
+    const uint8_t *pk_seed, const uint8_t *sk_seed,
+    const uint8_t *adrs, uint8_t *out, size_t out_len)
 {
-    const SLH_DSA_PARAMS *params = ctx->key->params;
+    const SLH_DSA_PARAMS *params = hctx->key->params;
     size_t n = params->n;
+    KECCAK1600_CTX sctx = *((KECCAK1600_CTX *)(hctx->shactx_pkseed));
 
-    return xof_digest_3(ctx->md_ctx, pk_seed, n, adrs, SLH_ADRS_SIZE, m1, m1_len, out, n);
+    ossl_sha3_absorb(&sctx, adrs, SLH_ADRS_SIZE);
+    ossl_sha3_absorb(&sctx, sk_seed, n);
+    ossl_sha3_final(&sctx, out, n);
+    return 1;
 }
 
 static int
-slh_h_shake(SLH_DSA_HASH_CTX *ctx, const uint8_t *pk_seed, const uint8_t *adrs,
+slh_h_shake(SLH_DSA_HASH_CTX *hctx, const uint8_t *pk_seed, const uint8_t *adrs,
     const uint8_t *m1, const uint8_t *m2, uint8_t *out, size_t out_len)
 {
-    const SLH_DSA_PARAMS *params = ctx->key->params;
+    KECCAK1600_CTX ctx = *((KECCAK1600_CTX *)(hctx->shactx_pkseed)), *sctx = &ctx;
+    const SLH_DSA_PARAMS *params = hctx->key->params;
     size_t n = params->n;
 
-    return xof_digest_4(ctx->md_ctx, pk_seed, n, adrs, SLH_ADRS_SIZE, m1, n, m2, n, out, n);
+    ossl_sha3_absorb(sctx, adrs, SLH_ADRS_SIZE);
+    ossl_sha3_absorb(sctx, m1, n);
+    ossl_sha3_absorb(sctx, m2, n);
+    ossl_sha3_final(sctx, out, n);
+    return 1;
 }
 
+/* FIPS 205 Section 11.2.1 and 11.2.2 */
+
 static int
-slh_t_shake(SLH_DSA_HASH_CTX *ctx, const uint8_t *pk_seed, const uint8_t *adrs,
-    const uint8_t *ml, size_t ml_len, uint8_t *out, size_t out_len)
+slh_hmsg_sha256(SLH_DSA_HASH_CTX *hctx, const uint8_t *r, const uint8_t *pk_seed,
+    const uint8_t *pk_root, const uint8_t *msg, size_t msg_len,
+    uint8_t *out, size_t out_len)
 {
-    const SLH_DSA_PARAMS *params = ctx->key->params;
+    SHA256_CTX ctx, *sctx = &ctx;
+    const SLH_DSA_PARAMS *params = hctx->key->params;
+    size_t m = params->m;
     size_t n = params->n;
+    uint8_t seed[2 * SLH_MAX_N + SHA256_DIGEST_LENGTH];
+    long seed_len = SHA256_DIGEST_LENGTH + (long)(2 * n);
 
-    return xof_digest_3(ctx->md_ctx, pk_seed, n, adrs, SLH_ADRS_SIZE, ml, ml_len, out, n);
-}
+    memcpy(seed, r, n);
+    memcpy(seed + n, pk_seed, n);
 
-static ossl_inline int
-digest_4(EVP_MD_CTX *ctx,
-    const uint8_t *in1, size_t in1_len, const uint8_t *in2, size_t in2_len,
-    const uint8_t *in3, size_t in3_len, const uint8_t *in4, size_t in4_len,
-    uint8_t *out)
-{
-    return (EVP_DigestInit_ex2(ctx, NULL, NULL) == 1
-        && EVP_DigestUpdate(ctx, in1, in1_len) == 1
-        && EVP_DigestUpdate(ctx, in2, in2_len) == 1
-        && EVP_DigestUpdate(ctx, in3, in3_len) == 1
-        && EVP_DigestUpdate(ctx, in4, in4_len) == 1
-        && EVP_DigestFinal_ex(ctx, out, NULL) == 1);
+    SHA256_Init(sctx);
+    SHA256_Update(sctx, r, n);
+    SHA256_Update(sctx, pk_seed, n);
+    SHA256_Update(sctx, pk_root, n);
+    SHA256_Update(sctx, msg, msg_len);
+    return SHA256_Final(seed + 2 * n, sctx)
+        && (PKCS1_MGF1(out, (long)m, seed, seed_len, hctx->key->md) == 0);
 }
 
-/* FIPS 205 Section 11.2.1 and 11.2.2 */
-
 static int
-slh_hmsg_sha2(SLH_DSA_HASH_CTX *hctx, const uint8_t *r, const uint8_t *pk_seed,
+slh_hmsg_sha512(SLH_DSA_HASH_CTX *hctx, const uint8_t *r, const uint8_t *pk_seed,
     const uint8_t *pk_root, const uint8_t *msg, size_t msg_len,
     uint8_t *out, size_t out_len)
 {
+    SHA512_CTX ctx, *sctx = &ctx;
     const SLH_DSA_PARAMS *params = hctx->key->params;
     size_t m = params->m;
     size_t n = params->n;
-    uint8_t seed[2 * SLH_MAX_N + MAX_DIGEST_SIZE];
-    int sz = EVP_MD_get_size(hctx->key->md_big);
-    size_t seed_len = (size_t)sz + 2 * n;
-
-    if (sz <= 0)
-        return 0;
+    uint8_t seed[2 * SLH_MAX_N + SHA512_DIGEST_LENGTH];
+    long seed_len = SHA512_DIGEST_LENGTH + (long)(2 * n);
 
     memcpy(seed, r, n);
     memcpy(seed + n, pk_seed, n);
-    return digest_4(hctx->md_big_ctx, r, n, pk_seed, n, pk_root, n, msg, msg_len,
-               seed + 2 * n)
-        && (PKCS1_MGF1(out, (long)m, seed, (long)seed_len, hctx->key->md_big) == 0);
+
+    SHA512_Init(sctx);
+    SHA512_Update(sctx, r, n);
+    SHA512_Update(sctx, pk_seed, n);
+    SHA512_Update(sctx, pk_root, n);
+    SHA512_Update(sctx, msg, msg_len);
+    return SHA512_Final(seed + 2 * n, sctx)
+        && (PKCS1_MGF1(out, (long)m, seed, seed_len, hctx->key->md_sha512) == 0);
 }
 
 static int
@@ -188,10 +209,11 @@ slh_prf_msg_sha2(SLH_DSA_HASH_CTX *hctx,
      * So we do a lazy update here on the first call.
      */
     if (hctx->hmac_digest_used == 0) {
+        const char *nm = EVP_MD_get0_name(key->md_sha512 == NULL ? key->md : key->md_sha512);
+
         p = params;
         /* The underlying digest to be used */
-        *p++ = OSSL_PARAM_construct_utf8_string(OSSL_MAC_PARAM_DIGEST,
-            (char *)EVP_MD_get0_name(key->md_big), 0);
+        *p++ = OSSL_PARAM_construct_utf8_string(OSSL_MAC_PARAM_DIGEST, (char *)nm, 0);
         if (key->propq != NULL)
             *p++ = OSSL_PARAM_construct_utf8_string(OSSL_MAC_PARAM_PROPERTIES,
                 (char *)key->propq, 0);
@@ -208,79 +230,272 @@ slh_prf_msg_sha2(SLH_DSA_HASH_CTX *hctx,
     return ret;
 }
 
-static ossl_inline int
-do_hash(EVP_MD_CTX *ctx, size_t n, const uint8_t *pk_seed, const uint8_t *adrs,
-    const uint8_t *m, size_t m_len, size_t b, uint8_t *out, size_t out_len)
+static int
+slh_prf_sha256(SLH_DSA_HASH_CTX *hctx, const uint8_t *pk_seed,
+    const uint8_t *sk_seed, const uint8_t *adrs,
+    uint8_t *out, size_t out_len)
 {
-    int ret;
-    uint8_t zeros[128] = { 0 };
-    uint8_t digest[MAX_DIGEST_SIZE];
+    SHA256_CTX ctx = *((SHA256_CTX *)hctx->shactx_pkseed), *sctx = &ctx;
+    size_t n = hctx->key->params->n;
 
-    ret = digest_4(ctx, pk_seed, n, zeros, b - n, adrs, SLH_ADRSC_SIZE,
-        m, m_len, digest);
-    /* Truncated returned value is n = 16 bytes */
-    memcpy(out, digest, n);
-    return ret;
+    SHA256_Update(sctx, adrs, SLH_ADRSC_SIZE);
+    SHA256_Update(sctx, sk_seed, n);
+    sha256_final(sctx, out, n);
+    return 1;
 }
 
 static int
-slh_prf_sha2(SLH_DSA_HASH_CTX *hctx, const uint8_t *pk_seed,
-    const uint8_t *sk_seed, const uint8_t *adrs,
-    uint8_t *out, size_t out_len)
+slh_wots_pk_gen_sha2(SLH_DSA_HASH_CTX *hctx,
+    const uint8_t *sk_seed, const uint8_t *pk_seed,
+    uint8_t *adrs, uint8_t *pk_out, size_t pk_out_len)
 {
+    int ret = 0;
     size_t n = hctx->key->params->n;
+    size_t i, j = 0, len = SLH_WOTS_LEN(n);
+    uint8_t sk[SLH_MAX_N];
+    SHA256_CTX *sctx = (SHA256_CTX *)(hctx->shactx_pkseed);
+    SHA256_CTX ctx;
+    const SLH_ADRS_FUNC *adrsf = hctx->key->adrs_func;
+    SLH_ADRS_DECLARE(sk_adrs);
+    SLH_ADRS_FN_DECLARE(adrsf, set_chain_address);
+    SLH_ADRS_FN_DECLARE(adrsf, set_hash_address);
+
+    adrsf->copy(sk_adrs, adrs);
+    adrsf->set_type_and_clear(sk_adrs, SLH_ADRS_TYPE_WOTS_PRF);
+    adrsf->copy_keypair_address(sk_adrs, adrs);
+
+    for (i = 0; i < len; ++i) { /* len = 2n + 3 */
+        set_chain_address(sk_adrs, (uint32_t)i);
+
+        /* PRF */
+        ctx = *sctx;
+        SHA256_Update(&ctx, sk_adrs, SLH_ADRSC_SIZE);
+        SHA256_Update(&ctx, sk_seed, n);
+        sha256_final(&ctx, sk, n);
+
+        set_chain_address(adrs, (uint32_t)i);
+        for (j = 0; j < NIBBLE_MASK; ++j) {
+            set_hash_address(adrs, (uint32_t)j);
+            /* F */
+            ctx = *sctx;
+            SHA256_Update(&ctx, adrs, SLH_ADRSC_SIZE);
+            SHA256_Update(&ctx, sk, n);
+            sha256_final(&ctx, sk, n);
+        }
+        memcpy(pk_out, sk, n);
+        pk_out += n;
+    }
+    ret = 1;
+    return ret;
+}
 
-    return do_hash(hctx->md_ctx, n, pk_seed, adrs, sk_seed, n,
-        OSSL_SLH_DSA_SHA2_NUM_ZEROS_H_AND_T_BOUND1, out, out_len);
+int slh_wots_pk_gen_shake(SLH_DSA_HASH_CTX *hctx,
+    const uint8_t *sk_seed, const uint8_t *pk_seed,
+    uint8_t *adrs, uint8_t *pk_out, size_t pk_out_len)
+{
+    int ret = 0;
+    size_t n = hctx->key->params->n;
+    size_t i, j = 0, len = SLH_WOTS_LEN(n);
+    uint8_t sk[SLH_MAX_N];
+    const SLH_ADRS_FUNC *adrsf = hctx->key->adrs_func;
+    SLH_ADRS_DECLARE(sk_adrs);
+    SLH_ADRS_FN_DECLARE(adrsf, set_chain_address);
+    SLH_ADRS_FN_DECLARE(adrsf, set_hash_address);
+    KECCAK1600_CTX *sctx = (KECCAK1600_CTX *)(hctx->shactx_pkseed);
+    KECCAK1600_CTX ctx;
+
+    adrsf->copy(sk_adrs, adrs);
+    adrsf->set_type_and_clear(sk_adrs, SLH_ADRS_TYPE_WOTS_PRF);
+    adrsf->copy_keypair_address(sk_adrs, adrs);
+
+    for (i = 0; i < len; ++i) { /* len = 2n + 3 */
+        set_chain_address(sk_adrs, (uint32_t)i);
+
+        /* PRF */
+        ctx = *sctx;
+        ossl_sha3_absorb(&ctx, sk_adrs, SLH_ADRS_SIZE);
+        ossl_sha3_absorb(&ctx, sk_seed, n);
+        ossl_sha3_final(&ctx, sk, n);
+
+        set_chain_address(adrs, (uint32_t)i);
+        for (j = 0; j < NIBBLE_MASK; ++j) {
+            set_hash_address(adrs, (uint32_t)j);
+            /* F */
+            ctx = *sctx;
+            ossl_sha3_absorb(&ctx, adrs, SLH_ADRS_SIZE);
+            ossl_sha3_absorb(&ctx, sk, n);
+            ossl_sha3_final(&ctx, sk, n);
+        }
+        memcpy(pk_out, sk, n);
+        pk_out += n;
+    }
+    ret = 1;
+    return ret;
 }
 
 static int
-slh_f_sha2(SLH_DSA_HASH_CTX *hctx, const uint8_t *pk_seed, const uint8_t *adrs,
+slh_f_sha256(SLH_DSA_HASH_CTX *hctx, const uint8_t *pk_seed, const uint8_t *adrs,
     const uint8_t *m1, size_t m1_len, uint8_t *out, size_t out_len)
 {
-    return do_hash(hctx->md_ctx, hctx->key->params->n, pk_seed, adrs, m1, m1_len,
-        OSSL_SLH_DSA_SHA2_NUM_ZEROS_H_AND_T_BOUND1, out, out_len);
+    SHA256_CTX ctx = *((SHA256_CTX *)hctx->shactx_pkseed), *sctx = &ctx;
+
+    SHA256_Update(sctx, adrs, SLH_ADRSC_SIZE);
+    SHA256_Update(sctx, m1, m1_len);
+    sha256_final(sctx, out, hctx->key->params->n);
+    return 1;
 }
 
 static int
-slh_h_sha2(SLH_DSA_HASH_CTX *hctx, const uint8_t *pk_seed, const uint8_t *adrs,
+slh_h_sha256(SLH_DSA_HASH_CTX *hctx, const uint8_t *pk_seed, const uint8_t *adrs,
     const uint8_t *m1, const uint8_t *m2, uint8_t *out, size_t out_len)
 {
-    uint8_t m[SLH_MAX_N * 2];
+    SHA256_CTX ctx = *((SHA256_CTX *)hctx->shactx_pkseed), *sctx = &ctx;
     const SLH_DSA_PARAMS *prms = hctx->key->params;
     size_t n = prms->n;
 
-    memcpy(m, m1, n);
-    memcpy(m + n, m2, n);
-    return do_hash(hctx->md_big_ctx, n, pk_seed, adrs, m, 2 * n,
-        prms->sha2_h_and_t_bound, out, out_len);
+    SHA256_Update(sctx, adrs, SLH_ADRSC_SIZE);
+    SHA256_Update(sctx, m1, n);
+    SHA256_Update(sctx, m2, n);
+    sha256_final(sctx, out, n);
+    return 1;
+}
+
+static int
+slh_h_sha512(SLH_DSA_HASH_CTX *hctx, const uint8_t *pk_seed, const uint8_t *adrs,
+    const uint8_t *m1, const uint8_t *m2, uint8_t *out, size_t out_len)
+{
+    SHA512_CTX ctx, *sctx = &ctx;
+    const SLH_DSA_PARAMS *prms = hctx->key->params;
+    size_t n = prms->n;
+
+    SHA512_Init(sctx);
+    SHA512_Update(sctx, pk_seed, n);
+    SHA512_Update(sctx, zeros, 128 - n);
+    SHA512_Update(sctx, adrs, SLH_ADRSC_SIZE);
+    SHA512_Update(sctx, m1, n);
+    SHA512_Update(sctx, m2, n);
+    sha512_final(sctx, out, n);
+    return 1;
+}
+
+static int
+slh_t_sha256(SLH_DSA_HASH_CTX *hctx, const uint8_t *pk_seed, const uint8_t *adrs,
+    const uint8_t *ml, size_t ml_len, uint8_t *out, size_t out_len)
+{
+    SHA256_CTX ctx = *((SHA256_CTX *)hctx->shactx_pkseed), *sctx = &ctx;
+
+    SHA256_Update(sctx, adrs, SLH_ADRSC_SIZE);
+    SHA256_Update(sctx, ml, ml_len);
+    sha256_final(sctx, out, hctx->key->params->n);
+    return 1;
 }
 
 static int
-slh_t_sha2(SLH_DSA_HASH_CTX *hctx, const uint8_t *pk_seed, const uint8_t *adrs,
+slh_t_sha512(SLH_DSA_HASH_CTX *hctx, const uint8_t *pk_seed, const uint8_t *adrs,
     const uint8_t *ml, size_t ml_len, uint8_t *out, size_t out_len)
 {
+    SHA512_CTX ctx, *sctx = &ctx;
     const SLH_DSA_PARAMS *prms = hctx->key->params;
+    size_t n = prms->n;
+
+    SHA512_Init(sctx);
+    SHA512_Update(sctx, pk_seed, n);
+    SHA512_Update(sctx, zeros, 128 - n);
+    SHA512_Update(sctx, adrs, SLH_ADRSC_SIZE);
+    SHA512_Update(sctx, ml, ml_len);
+    sha512_final(sctx, out, hctx->key->params->n);
+    return 1;
+}
+
+static int slh_hash_shake_precache(SLH_DSA_HASH_CTX *hctx, const uint8_t *pkseed, size_t n)
+{
+    KECCAK1600_CTX *ctx = NULL, *seedctx = NULL;
 
-    return do_hash(hctx->md_big_ctx, prms->n, pk_seed, adrs, ml, ml_len,
-        prms->sha2_h_and_t_bound, out, out_len);
+    ctx = ossl_shake256_new();
+    if (ctx == NULL)
+        return 0;
+    seedctx = OPENSSL_memdup(ctx, sizeof(*ctx));
+    if (seedctx == NULL) {
+        OPENSSL_free(ctx);
+        return 0;
+    }
+    ossl_sha3_absorb(seedctx, pkseed, n);
+    hctx->shactx = (void *)ctx;
+    hctx->shactx_pkseed = (void *)seedctx;
+    return 1;
 }
 
-const SLH_HASH_FUNC *ossl_slh_get_hash_fn(int is_shake)
+static int slh_hash_shake_dup(SLH_DSA_HASH_CTX *dst, const SLH_DSA_HASH_CTX *src)
+{
+    if (src->shactx != NULL) {
+        dst->shactx = OPENSSL_memdup(src->shactx, sizeof(KECCAK1600_CTX));
+        if (dst->shactx == NULL)
+            return 0;
+    }
+    if (src->shactx_pkseed != NULL) {
+        dst->shactx_pkseed = OPENSSL_memdup(src->shactx_pkseed, sizeof(KECCAK1600_CTX));
+        if (dst->shactx_pkseed == NULL) {
+            OPENSSL_free(dst->shactx);
+            dst->shactx = NULL;
+            return 0;
+        }
+    }
+    return 1;
+}
+
+static int slh_hash_sha256_precache(SLH_DSA_HASH_CTX *hctx, const uint8_t *pkseed, size_t n)
+{
+    SHA256_CTX *ctx = OPENSSL_zalloc(sizeof(*ctx));
+
+    if (ctx == NULL)
+        return 0;
+    SHA256_Init(ctx);
+    SHA256_Update(ctx, pkseed, n);
+    SHA256_Update(ctx, zeros, 64 - n);
+    hctx->shactx_pkseed = (void *)ctx;
+    return 1;
+}
+
+static int slh_hash_sha256_dup(SLH_DSA_HASH_CTX *dst, const SLH_DSA_HASH_CTX *src)
+{
+    if (src->shactx_pkseed != NULL) {
+        dst->shactx_pkseed = OPENSSL_memdup(src->shactx_pkseed, sizeof(SHA256_CTX));
+        if (dst->shactx_pkseed == NULL)
+            return 0;
+    }
+    return 1;
+}
+
+const SLH_HASH_FUNC *ossl_slh_get_hash_fn(int is_shake, int security_category)
 {
     static const SLH_HASH_FUNC methods[] = {
-        { slh_hmsg_shake,
+        { slh_hash_shake_precache,
+            slh_hash_shake_dup,
+            slh_hmsg_shake,
             slh_prf_shake,
             slh_prf_msg_shake,
             slh_f_shake,
             slh_h_shake,
-            slh_t_shake },
-        { slh_hmsg_sha2,
-            slh_prf_sha2,
+            slh_f_shake,
+            slh_wots_pk_gen_shake },
+        { slh_hash_sha256_precache,
+            slh_hash_sha256_dup,
+            slh_hmsg_sha256,
+            slh_prf_sha256,
+            slh_prf_msg_sha2,
+            slh_f_sha256,
+            slh_h_sha256,
+            slh_t_sha256,
+            slh_wots_pk_gen_sha2 },
+        { slh_hash_sha256_precache,
+            slh_hash_sha256_dup,
+            slh_hmsg_sha512,
+            slh_prf_sha256,
             slh_prf_msg_sha2,
-            slh_f_sha2,
-            slh_h_sha2,
-            slh_t_sha2 }
+            slh_f_sha256,
+            slh_h_sha512,
+            slh_t_sha512,
+            slh_wots_pk_gen_sha2 }
     };
-    return &methods[is_shake ? 0 : 1];
+    return &methods[is_shake ? 0 : (security_category == 1 ? 1 : 2)];
 }
index 51e542139f8c109e7b395e15014ac83666d4b74c..48f3f64e29d2082c6e2d41e56ffd6248656ce232 100644 (file)
@@ -47,20 +47,29 @@ typedef int(OSSL_SLH_HASHFUNC_H)(SLH_DSA_HASH_CTX *ctx, const uint8_t *pk_seed,
     const uint8_t *m1, const uint8_t *m2,
     uint8_t *out, size_t out_len);
 
-typedef int(OSSL_SLH_HASHFUNC_T)(SLH_DSA_HASH_CTX *ctx, const uint8_t *pk_seed,
-    const uint8_t *adrs,
-    const uint8_t *m1, size_t m1_len,
-    uint8_t *out, size_t out_len);
+#define OSSL_SLH_HASHFUNC_T OSSL_SLH_HASHFUNC_F
+
+typedef int(OSSL_SLH_HASHFUNC_wots_pk_gen)(SLH_DSA_HASH_CTX *hctx,
+    const uint8_t *sk_seed, const uint8_t *pk_seed,
+    uint8_t *adrs, uint8_t *pk_out, size_t pk_out_len);
+
+typedef int(OSSL_SLH_HASHFUNC_prehash_pk_seed)(SLH_DSA_HASH_CTX *hctx,
+    const uint8_t *pk_seed, size_t n);
+typedef int(OSSL_SLH_HASHFUNC_prehash_dup)(SLH_DSA_HASH_CTX *dst,
+    const SLH_DSA_HASH_CTX *src);
 
 typedef struct slh_hash_func_st {
+    OSSL_SLH_HASHFUNC_prehash_pk_seed *prehash_pk_seed;
+    OSSL_SLH_HASHFUNC_prehash_dup *prehash_dup;
     OSSL_SLH_HASHFUNC_H_MSG *H_MSG;
     OSSL_SLH_HASHFUNC_PRF *PRF;
     OSSL_SLH_HASHFUNC_PRF_MSG *PRF_MSG;
     OSSL_SLH_HASHFUNC_F *F;
     OSSL_SLH_HASHFUNC_H *H;
     OSSL_SLH_HASHFUNC_T *T;
+    OSSL_SLH_HASHFUNC_wots_pk_gen *wots_pk_gen;
 } SLH_HASH_FUNC;
 
-const SLH_HASH_FUNC *ossl_slh_get_hash_fn(int is_shake);
+const SLH_HASH_FUNC *ossl_slh_get_hash_fn(int is_shake, int security_category);
 
 #endif
index 1c612e8561fd779c4f6356cb973a0b433345ea26..cedd5e7c612c9d9caa108ad869728538ef1ae06b 100644 (file)
@@ -142,45 +142,22 @@ int ossl_slh_wots_pk_gen(SLH_DSA_HASH_CTX *ctx,
     int ret = 0;
     const SLH_DSA_KEY *key = ctx->key;
     size_t n = key->params->n;
-    size_t i, len = SLH_WOTS_LEN(n); /* 2 * n + 3 */
-    uint8_t sk[SLH_MAX_N];
+    size_t len = SLH_WOTS_LEN(n); /* 2 * n + 3 */
     uint8_t tmp[SLH_WOTS_LEN_MAX * SLH_MAX_N];
-    WPACKET pkt, *tmp_wpkt = &pkt; /* Points to the |tmp| buffer */
-    size_t tmp_len = 0;
+    size_t tmp_len = n * len;
 
     SLH_HASH_FUNC_DECLARE(key, hashf);
     SLH_ADRS_FUNC_DECLARE(key, adrsf);
-    SLH_HASH_FN_DECLARE(hashf, PRF);
-    SLH_ADRS_FN_DECLARE(adrsf, set_chain_address);
-    SLH_ADRS_DECLARE(sk_adrs);
     SLH_ADRS_DECLARE(wots_pk_adrs);
 
-    if (!WPACKET_init_static_len(tmp_wpkt, tmp, sizeof(tmp), 0))
-        return 0;
-    adrsf->copy(sk_adrs, adrs);
-    adrsf->set_type_and_clear(sk_adrs, SLH_ADRS_TYPE_WOTS_PRF);
-    adrsf->copy_keypair_address(sk_adrs, adrs);
-
-    for (i = 0; i < len; ++i) { /* len = 2n + 3 */
-        set_chain_address(sk_adrs, (uint32_t)i);
-        if (!PRF(ctx, pk_seed, sk_seed, sk_adrs, sk, sizeof(sk)))
-            goto end;
-
-        set_chain_address(adrs, (uint32_t)i);
-        if (!slh_wots_chain(ctx, sk, 0, NIBBLE_MASK, pk_seed, adrs, tmp_wpkt))
-            goto end;
-    }
-
-    if (!WPACKET_get_total_written(tmp_wpkt, &tmp_len)) /* should be n * (2 * n + 3) */
+    if (!hashf->wots_pk_gen(ctx, sk_seed, pk_seed, adrs, tmp, tmp_len))
         goto end;
+
     adrsf->copy(wots_pk_adrs, adrs);
     adrsf->set_type_and_clear(wots_pk_adrs, SLH_ADRS_TYPE_WOTS_PK);
     adrsf->copy_keypair_address(wots_pk_adrs, adrs);
     ret = hashf->T(ctx, pk_seed, wots_pk_adrs, tmp, tmp_len, pk_out, pk_out_len);
 end:
-    WPACKET_finish(tmp_wpkt);
-    OPENSSL_cleanse(tmp, sizeof(tmp));
-    OPENSSL_cleanse(sk, n);
     return ret;
 }
 
index 7e928a3582a151de3b6f2ef612798caba55a0c9f..d82c0c9033cdbea82f672bb7e03d7738a7564d31 100644 (file)
@@ -241,7 +241,7 @@ int HASH_FINAL(unsigned char *md, HASH_CTX *c)
      * Pad the input by adding a 1 bit + K zero bits + input length (L)
      * as a 64 bit value. K must align the data to a chunk boundary.
      */
-    p[n] = 0x80;                /* there is always room for one */
+    p[n] = 0x80; /* there is always room for one */
     n++;
 
     if (n > (HASH_CBLOCK - 8)) {
index b8ad91fda2ee36da41aefe4f0effbf1cd69ddbab..0d233a5424bd0a67b69a47ae6c48f689ea0870f7 100644 (file)
@@ -57,6 +57,8 @@ __owur int ossl_slh_dsa_key_type_matches(const SLH_DSA_KEY *key, const char *alg
 __owur SLH_DSA_HASH_CTX *ossl_slh_dsa_hash_ctx_new(const SLH_DSA_KEY *key);
 void ossl_slh_dsa_hash_ctx_free(SLH_DSA_HASH_CTX *ctx);
 __owur SLH_DSA_HASH_CTX *ossl_slh_dsa_hash_ctx_dup(const SLH_DSA_HASH_CTX *src);
+__owur int ossl_slh_dsa_hash_ctx_prehash_pk_seed(SLH_DSA_HASH_CTX *ctx,
+    const uint8_t *pkseed, size_t n);
 
 __owur int ossl_slh_dsa_sign(SLH_DSA_HASH_CTX *slh_ctx,
     const uint8_t *msg, size_t msg_len,
index 6b3f76a5b822ca9016f424d697f95e58f8e162db..b71581cf4ace2365eb50c69f70e906457dc2f310 100644 (file)
@@ -48,6 +48,7 @@ struct keccak_st {
     int xof_state;
 };
 
+KECCAK1600_CTX *ossl_shake256_new(void);
 void ossl_sha3_reset(KECCAK1600_CTX *ctx);
 int ossl_sha3_init(KECCAK1600_CTX *ctx, unsigned char pad, size_t bitlen);
 int ossl_keccak_init(KECCAK1600_CTX *ctx, unsigned char pad,