]> git.ipfire.org Git - thirdparty/nettle.git/commitdiff
wip slh_hash abstraction
authorNiels Möller <nisse@lysator.liu.se>
Wed, 2 Jul 2025 14:31:30 +0000 (16:31 +0200)
committerNiels Möller <nisse@lysator.liu.se>
Sat, 5 Jul 2025 08:14:55 +0000 (10:14 +0200)
Makefile.in
slh-dsa-internal.h
slh-dsa-shake-128f.c
slh-dsa-shake-128s.c
slh-dsa.c
slh-fors.c
slh-sha256.c [new file with mode: 0644]
slh-shake.c
slh-wots.c
slh-xmss.c

index 88ecf211d4dd0405812814fec19f644cd31c9752..da379772e61d9272a73d942f61ed479318f2e865 100644 (file)
@@ -176,7 +176,7 @@ nettle_SOURCES = aes-decrypt-internal.c aes-decrypt-table.c \
                 yarrow256.c yarrow_key_event.c \
                 xts.c xts-aes128.c xts-aes256.c \
                 drbg-ctr-aes256.c \
-                slh-fors.c slh-merkle.c slh-shake.c slh-wots.c slh-xmss.c \
+                slh-fors.c slh-merkle.c slh-shake.c slh-sha256.c slh-wots.c slh-xmss.c \
                 slh-dsa.c slh-dsa-shake-128s.c slh-dsa-shake-128f.c
 
 hogweed_SOURCES = sexp.c sexp-format.c \
index 54e8499e2be1599669c3d85a62a6869ddd5b6e62..1fd77c084c5e50df7f6fa164fdd92d22e9c893db 100644 (file)
 
 #include <stdint.h>
 
+#include "nettle-types.h"
+
 /* Name mangling */
 #define _slh_shake_init _nettle_slh_shake_init
 #define _slh_shake _nettle_slh_shake
+#define _slh_shake_digest _nettle_slh_shake_digest
+#define _slh_shake_randomizer _nettle_slh_shake_randomizer
+#define _slh_shake_msg_digest _nettle_slh_shake_msg_digest
+#define _slh_sha256_init _nettle_slh_sha256_init
+#define _slh_sha256 _nettle_slh_sha256
+#define _slh_sha256_randomizer _nettle_slh_sha256_randomizer
+#define _slh_sha256_digest _nettle_slh_sha256_digest
 #define _wots_gen _nettle_wots_gen
 #define _wots_sign _nettle_wots_sign
 #define _wots_verify _nettle_wots_verify
@@ -49,8 +58,6 @@
 #define _xmss_gen _nettle_xmss_gen
 #define _xmss_sign _nettle_xmss_sign
 #define _xmss_verify _nettle_xmss_verify
-#define _slh_dsa_randomizer _nettle_slh_dsa_randomizer
-#define _slh_dsa_digest _nettle_slh_dsa_digest
 #define _slh_dsa_sign _nettle_slh_dsa_sign
 #define _slh_dsa_verify _nettle_slh_dsa_verify
 
@@ -82,11 +89,40 @@ enum slh_addr_type
     SLH_FORS_PRF = 6,
   };
 
-struct sha3_ctx;
-struct slh_merkle_ctx_public
+typedef void slh_hash_init_func (void *tree_ctx, const uint8_t *public_seed,
+                                uint32_t layer, uint64_t tree_idx);
+typedef void slh_hash_secret_func (const void *tree_ctx,
+                                  const struct slh_address_hash *ah,
+                                  const uint8_t *secret, uint8_t *out);
+typedef void slh_hash_node_func (const void *tree_ctx,
+                                const struct slh_address_hash *ah,
+                                const uint8_t *left, const uint8_t *right,
+                                uint8_t *out);
+typedef void slh_hash_start_func (const void *tree_ctx, void *ctx, const struct slh_address_hash *ah);
+
+struct slh_hash
 {
+  slh_hash_init_func *init;
+  slh_hash_secret_func *secret;
+  slh_hash_node_func *node;
+  slh_hash_start_func *start;
+  nettle_hash_update_func *update;
+  nettle_hash_digest_func *digest;
+};
+
+extern const struct slh_hash _slh_hash_shake;
+struct slh_hash_ctxs
+{
+  const struct slh_hash *hash;
   /* Initialized based on public seed and slh_address_tree. */
-  const struct sha3_ctx *tree_ctx;
+  const void *tree;
+  /* Working ctx for wots and fors. */
+  void *scratch;
+};
+
+struct slh_merkle_ctx_public
+{
+  struct slh_hash_ctxs ctx;
   unsigned keypair; /* Used only by fors_leaf and fors_node. */
 };
 
@@ -107,7 +143,6 @@ struct slh_fors_params
 {
   unsigned short a; /* Height of tree. */
   unsigned short k; /* Number of trees. */
-  unsigned short msg_size;
   unsigned short signature_size;
 };
 
@@ -117,8 +152,9 @@ struct slh_dsa_params
   struct slh_fors_params fors;
 };
 
-extern const struct slh_dsa_params _slh_dsa_shake_128s_params;
+extern const struct slh_dsa_params _slh_dsa_128s_params;
 
+struct sha3_ctx;
 void
 _slh_shake_init (struct sha3_ctx *ctx, const uint8_t *public_seed,
                 uint32_t layer, uint64_t tree_idx);
@@ -128,25 +164,57 @@ _slh_shake (const struct sha3_ctx *tree_ctx,
            const struct slh_address_hash *ah,
            const uint8_t *secret, uint8_t *out);
 
+void
+_slh_shake_digest (struct sha3_ctx *ctx, uint8_t *out);
+
+void
+_slh_shake_randomizer (const uint8_t *public_seed, const uint8_t *secret_prf,
+                      size_t msg_length, const uint8_t *msg,
+                      uint8_t *randomizer);
+void
+_slh_shake_msg_digest (const uint8_t *randomizer, const uint8_t *pub,
+                      size_t length, const uint8_t *msg,
+                      size_t digest_size, uint8_t *digest);
+
+struct sha256_ctx;
+void
+_slh_sha256_init (struct sha256_ctx *ctx, const uint8_t *public_seed,
+                uint32_t layer, uint64_t tree_idx);
+
+void
+_slh_sha256 (const struct sha256_ctx *tree_ctx,
+            const struct slh_address_hash *ah,
+            const uint8_t *secret, uint8_t *out);
+
+void
+_slh_sha256_randomizer (const uint8_t *public_seed, const uint8_t *secret_prf,
+                       size_t msg_length, const uint8_t *msg,
+                       uint8_t *randomizer);
+void
+_slh_sha256_msg_digest (const uint8_t *randomizer, const uint8_t *pub,
+                       size_t length, const uint8_t *msg,
+                       size_t digest_size, uint8_t *digest);
+
 #define _WOTS_SIGNATURE_LENGTH 35
 /* 560 bytes */
 #define WOTS_SIGNATURE_SIZE (_WOTS_SIGNATURE_LENGTH*_SLH_DSA_128_SIZE)
 
 void
-_wots_gen (const struct sha3_ctx *tree_ctx, const uint8_t *secret_seed,
+_wots_gen (const struct slh_hash_ctxs *ctx, const uint8_t *secret_seed,
           uint32_t keypair, uint8_t *pub);
 
 void
-_wots_sign (const struct sha3_ctx *tree_ctx, const uint8_t *secret_seed,
+_wots_sign (const struct slh_hash_ctxs *ctx, const uint8_t *secret_seed,
            unsigned keypair, const uint8_t *msg, uint8_t *signature, uint8_t *pub);
 
 /* Computes candidate public key from signature. */
 void
-_wots_verify (const struct sha3_ctx *tree_ctx,
+_wots_verify (struct slh_hash_ctxs *ctx,
              unsigned keypair, const uint8_t *msg, const uint8_t *signature, uint8_t *pub);
 
-/* Merkle tree functions. Could be generalized for other merkle tree
-   applications, by using const void* for the ctx argument. */
+/* Merkle tree functions. Leaf function uses a non-const context, to allow the ctx to point at
+   working storage. Could be generalized for other merkle tree
+   applications, by using void * for the ctx argument. */
 typedef void merkle_leaf_hash_func (const struct slh_merkle_ctx_secret *ctx, unsigned index, uint8_t *out);
 typedef void merkle_node_hash_func (const struct slh_merkle_ctx_public *ctx, unsigned height, unsigned index,
                                    const uint8_t *left, const uint8_t *right, uint8_t *out);
@@ -192,7 +260,9 @@ _fors_verify (const struct slh_merkle_ctx_public *ctx,
 
 /* Provided scratch must be of size (xmss->h + 1) * _SLH_DSA_128_SIZE. */
 void
-_xmss_gen (const uint8_t *public_seed, const uint8_t *secret_seed,
+_xmss_gen (const struct slh_hash *hash,
+          void *ha, void *hb,
+          const uint8_t *public_seed, const uint8_t *secret_seed,
           const struct slh_xmss_params *xmss,
           uint8_t *scratch, uint8_t *root);
 
@@ -207,21 +277,18 @@ _xmss_verify (const struct slh_merkle_ctx_public *ctx, unsigned h,
              unsigned idx, const uint8_t *msg, const uint8_t *signature, uint8_t *pub);
 
 void
-_slh_dsa_randomizer (const uint8_t *public_seed, const uint8_t *secret_prf,
-                    size_t msg_length, const uint8_t *msg,
-                    uint8_t *randomizer);
-void
-_slh_dsa_digest (const uint8_t *randomizer, const uint8_t *pub,
-                size_t length, const uint8_t *msg,
-                size_t digest_size, uint8_t *digest);
-void
 _slh_dsa_sign (const struct slh_dsa_params *params,
+              const struct slh_hash *hash,
+              void *ha, void *hb,
               const uint8_t *pub, const uint8_t *priv,
               const uint8_t *digest,
               uint64_t tree_idx, unsigned leaf_idx,
               uint8_t *signature);
 int
-_slh_dsa_verify (const struct slh_dsa_params *params, const uint8_t *pub,
+_slh_dsa_verify (const struct slh_dsa_params *params,
+                const struct slh_hash *hash,
+                void *ha, void *hb,
+                const uint8_t *pub,
                 const uint8_t *digest, uint64_t tree_idx, unsigned leaf_idx,
                 const uint8_t *signature);
 
index cc76d2b9934d9439741c05b959ad80399723af7d..3d7da32e976538fe8045012b7e902e4f067c70fc 100644 (file)
@@ -38,6 +38,8 @@
 #include "slh-dsa.h"
 #include "slh-dsa-internal.h"
 
+#include "sha3.h"
+
 #define SLH_DSA_M 34
 
 #define SLH_DSA_D 22
@@ -53,7 +55,7 @@ const struct slh_dsa_params
 _slh_dsa_shake_128f_params =
   {
     { SLH_DSA_D, XMSS_H, XMSS_SIGNATURE_SIZE (XMSS_H) },
-    { FORS_A, FORS_K, FORS_MSG_SIZE, FORS_SIGNATURE_SIZE (FORS_A, FORS_K) },
+    { FORS_A, FORS_K, FORS_SIGNATURE_SIZE (FORS_A, FORS_K) },
   };
 
 void
@@ -61,7 +63,9 @@ slh_dsa_shake_128f_root (const uint8_t *public_seed, const uint8_t *private_seed
                         uint8_t *root)
 {
   uint8_t scratch[(XMSS_H + 1)*_SLH_DSA_128_SIZE];
-  _xmss_gen (public_seed, private_seed, &_slh_dsa_shake_128f_params.xmss, scratch, root);
+  struct sha3_ctx ha, hb;
+  _xmss_gen (&_slh_hash_shake, &ha, &hb, public_seed, private_seed,
+            &_slh_dsa_shake_128f_params.xmss, scratch, root);
 }
 
 void
@@ -109,12 +113,15 @@ slh_dsa_shake_128f_sign (const uint8_t *pub, const uint8_t *priv,
   uint8_t digest[SLH_DSA_M];
   uint64_t tree_idx;
   unsigned leaf_idx;
+  struct sha3_ctx ha, hb;
 
-  _slh_dsa_randomizer (pub, priv + _SLH_DSA_128_SIZE, length, msg, signature);
-  _slh_dsa_digest (signature, pub, length, msg, SLH_DSA_M, digest);
+  _slh_shake_randomizer (pub, priv + _SLH_DSA_128_SIZE, length, msg, signature);
+  _slh_shake_msg_digest (signature, pub, length, msg, SLH_DSA_M, digest);
   parse_digest (digest + FORS_MSG_SIZE, &tree_idx, &leaf_idx);
 
-  _slh_dsa_sign (&_slh_dsa_shake_128f_params, pub, priv, digest, tree_idx, leaf_idx,
+  _slh_dsa_sign (&_slh_dsa_shake_128f_params,
+                &_slh_hash_shake, &ha, &hb,
+                pub, priv, digest, tree_idx, leaf_idx,
                 signature + _SLH_DSA_128_SIZE);
 }
 
@@ -126,9 +133,12 @@ slh_dsa_shake_128f_verify (const uint8_t *pub,
   uint8_t digest[SLH_DSA_M];
   uint64_t tree_idx;
   unsigned leaf_idx;
+  struct sha3_ctx ha, hb;
 
-  _slh_dsa_digest (signature, pub, length, msg, SLH_DSA_M,digest);
+  _slh_shake_msg_digest (signature, pub, length, msg, SLH_DSA_M,digest);
   parse_digest (digest + FORS_MSG_SIZE, &tree_idx, &leaf_idx);
-  return _slh_dsa_verify (&_slh_dsa_shake_128f_params, pub, digest, tree_idx, leaf_idx,
+  return _slh_dsa_verify (&_slh_dsa_shake_128f_params,
+                         &_slh_hash_shake, &ha, &hb,
+                         pub, digest, tree_idx, leaf_idx,
                          signature + _SLH_DSA_128_SIZE);
 }
index 510946ca35a515e7d1590bed7e8e97b4f4ee8228..f34c3bd26c1e06e5a2ffcd1dbef802924d1bfd07 100644 (file)
@@ -38,6 +38,8 @@
 #include "slh-dsa.h"
 #include "slh-dsa-internal.h"
 
+#include "sha3.h"
+
 #define SLH_DSA_M 30
 
 #define SLH_DSA_D 7
@@ -53,7 +55,7 @@ const struct slh_dsa_params
 _slh_dsa_shake_128s_params =
   {
     { SLH_DSA_D, XMSS_H, XMSS_SIGNATURE_SIZE (XMSS_H) },
-    { FORS_A, FORS_K, FORS_MSG_SIZE, FORS_SIGNATURE_SIZE (FORS_A, FORS_K) },
+    { FORS_A, FORS_K, FORS_SIGNATURE_SIZE (FORS_A, FORS_K) },
   };
 
 void
@@ -61,7 +63,9 @@ slh_dsa_shake_128s_root (const uint8_t *public_seed, const uint8_t *private_seed
                         uint8_t *root)
 {
   uint8_t scratch[(XMSS_H + 1)*_SLH_DSA_128_SIZE];
-  _xmss_gen (public_seed, private_seed, &_slh_dsa_shake_128s_params.xmss, scratch, root);
+  struct sha3_ctx ha, hb;
+  _xmss_gen (&_slh_hash_shake, &ha, &hb, public_seed, private_seed,
+            &_slh_dsa_shake_128s_params.xmss, scratch, root);
 }
 
 void
@@ -109,12 +113,15 @@ slh_dsa_shake_128s_sign (const uint8_t *pub, const uint8_t *priv,
   uint8_t digest[SLH_DSA_M];
   uint64_t tree_idx;
   unsigned leaf_idx;
+  struct sha3_ctx ha, hb;
 
-  _slh_dsa_randomizer (pub, priv + _SLH_DSA_128_SIZE, length, msg, signature);
-  _slh_dsa_digest (signature, pub, length, msg, SLH_DSA_M, digest);
+  _slh_shake_randomizer (pub, priv + _SLH_DSA_128_SIZE, length, msg, signature);
+  _slh_shake_msg_digest (signature, pub, length, msg, SLH_DSA_M, digest);
   parse_digest (digest + FORS_MSG_SIZE, &tree_idx, &leaf_idx);
 
-  _slh_dsa_sign (&_slh_dsa_shake_128s_params, pub, priv, digest, tree_idx, leaf_idx,
+  _slh_dsa_sign (&_slh_dsa_shake_128s_params,
+                &_slh_hash_shake, &ha, &hb,
+                pub, priv, digest, tree_idx, leaf_idx,
                 signature + _SLH_DSA_128_SIZE);
 }
 
@@ -126,9 +133,12 @@ slh_dsa_shake_128s_verify (const uint8_t *pub,
   uint8_t digest[SLH_DSA_M];
   uint64_t tree_idx;
   unsigned leaf_idx;
+  struct sha3_ctx ha, hb;
 
-  _slh_dsa_digest (signature, pub, length, msg, SLH_DSA_M,digest);
+  _slh_shake_msg_digest (signature, pub, length, msg, SLH_DSA_M,digest);
   parse_digest (digest + FORS_MSG_SIZE, &tree_idx, &leaf_idx);
-  return _slh_dsa_verify (&_slh_dsa_shake_128s_params, pub, digest, tree_idx, leaf_idx,
+  return _slh_dsa_verify (&_slh_dsa_shake_128s_params,
+                         &_slh_hash_shake, &ha, &hb,
+                         pub, digest, tree_idx, leaf_idx,
                          signature + _SLH_DSA_128_SIZE);
 }
index d5b0afab7f29486733b2847d47b306ae50bf4302..312b7aa050e0f1512238bad791e9a0b4f4c9bdbb 100644 (file)
--- a/slh-dsa.c
+++ b/slh-dsa.c
 #include <string.h>
 
 #include "memops.h"
-#include "sha3.h"
 #include "slh-dsa.h"
 #include "slh-dsa-internal.h"
 
+#if 0
+/* For 128s flavor. */
+#define SLH_DSA_M 30
 
-static const uint8_t slh_pure_prefix[2] = {0, 0};
+#define SLH_DSA_D 7
+#define XMSS_H 9
 
-void
-_slh_dsa_randomizer (const uint8_t *public_seed, const uint8_t *secret_prf,
-                    size_t msg_length, const uint8_t *msg,
-                    uint8_t *randomizer)
-{
-  struct sha3_ctx ctx;
-
-  sha3_init (&ctx);
-  sha3_256_update (&ctx, _SLH_DSA_128_SIZE, secret_prf);
-  sha3_256_update (&ctx, _SLH_DSA_128_SIZE, public_seed);
-  sha3_256_update (&ctx, sizeof (slh_pure_prefix), slh_pure_prefix);
-  sha3_256_update (&ctx, msg_length, msg);
-  sha3_256_shake (&ctx, _SLH_DSA_128_SIZE, randomizer);
-}
+/* Use k Merkle trees, each of size 2^a. Signs messages of size
+   k * a = 168 bits or 21 octets. */
+#define FORS_A 12
+#define FORS_K 14
 
-void
-_slh_dsa_digest (const uint8_t *randomizer, const uint8_t *pub,
-                size_t length, const uint8_t *msg,
-                size_t digest_size, uint8_t *digest)
-{
-  struct sha3_ctx ctx;
-
-  sha3_init (&ctx);
-  sha3_256_update (&ctx, _SLH_DSA_128_SIZE, randomizer);
-  sha3_256_update (&ctx, 2*_SLH_DSA_128_SIZE, pub);
-  sha3_256_update (&ctx, sizeof (slh_pure_prefix), slh_pure_prefix);
-  sha3_256_update (&ctx, length, msg);
-  sha3_256_shake (&ctx, digest_size, digest);
-}
+const struct slh_dsa_params
+_slh_dsa_128s_params =
+  {
+    { SLH_DSA_D, XMSS_H, XMSS_SIGNATURE_SIZE (XMSS_H) },
+    { FORS_A, FORS_K, FORS_SIGNATURE_SIZE (FORS_A, FORS_K) },
+  };
+#endif
 
 void
 _slh_dsa_sign (const struct slh_dsa_params *params,
+              const struct slh_hash *hash,
+              void *ha, void *hb,
               const uint8_t *pub, const uint8_t *priv,
               const uint8_t *digest,
               uint64_t tree_idx, unsigned leaf_idx,
@@ -86,13 +74,12 @@ _slh_dsa_sign (const struct slh_dsa_params *params,
   uint8_t root[_SLH_DSA_128_SIZE];
   int i;
 
-  struct sha3_ctx tree_ctx;
   struct slh_merkle_ctx_secret merkle_ctx =
     {
-      { &tree_ctx, leaf_idx },
+      { { hash, ha, hb }, leaf_idx },
       priv,
     };
-  _slh_shake_init (&tree_ctx, pub, 0, tree_idx);
+  hash->init(ha, pub, 0, tree_idx);
 
   _fors_sign (&merkle_ctx, &params->fors, digest, signature, root);
   signature += params->fors.signature_size;
@@ -106,7 +93,7 @@ _slh_dsa_sign (const struct slh_dsa_params *params,
       leaf_idx = tree_idx & ((1 << params->xmss.h) - 1);
       tree_idx >>= params->xmss.h;
 
-      _slh_shake_init (&tree_ctx, pub, i, tree_idx);
+      hash->init(ha, pub, i, tree_idx);
 
       _xmss_sign (&merkle_ctx, params->xmss.h, leaf_idx, root, signature, root);
     }
@@ -114,18 +101,20 @@ _slh_dsa_sign (const struct slh_dsa_params *params,
 }
 
 int
-_slh_dsa_verify (const struct slh_dsa_params *params, const uint8_t *pub,
+_slh_dsa_verify (const struct slh_dsa_params *params,
+                const struct slh_hash *hash,
+                void *ha, void *hb,
+                const uint8_t *pub,
                 const uint8_t *digest, uint64_t tree_idx, unsigned leaf_idx,
                 const uint8_t *signature)
 {
   uint8_t root[_SLH_DSA_128_SIZE];
   int i;
 
-  struct sha3_ctx tree_ctx;
   struct slh_merkle_ctx_public merkle_ctx =
-    { &tree_ctx, leaf_idx };
+    { { hash, ha, hb }, leaf_idx };
 
-  _slh_shake_init (&tree_ctx, pub, 0, tree_idx);
+  hash->init(ha, pub, 0, tree_idx);
 
   _fors_verify (&merkle_ctx, &params->fors, digest, signature, root);
   signature += params->fors.signature_size;
@@ -139,7 +128,7 @@ _slh_dsa_verify (const struct slh_dsa_params *params, const uint8_t *pub,
       leaf_idx = tree_idx & ((1 << params->xmss.h) - 1);
       tree_idx >>= params->xmss.h;
 
-      _slh_shake_init (&tree_ctx, pub, i, tree_idx);
+      hash->init(ha, pub, i, tree_idx);
 
       _xmss_verify (&merkle_ctx, params->xmss.h, leaf_idx, root, signature, root);
     }
index d3fc452c5fc11ff9dc514ee2d24c8b8784a4f7ac..28eff50cff1b00b74b60f48dd228e332346139c0 100644 (file)
@@ -41,6 +41,7 @@
 #include "sha3.h"
 #include "slh-dsa-internal.h"
 
+/* TODO: Like wots_gen, take hash ctx, secret seed (+ keypair) as separate arguments? */
 void
 _fors_gen (const struct slh_merkle_ctx_secret *ctx,
           unsigned idx, uint8_t *sk, uint8_t *leaf)
@@ -53,10 +54,10 @@ _fors_gen (const struct slh_merkle_ctx_secret *ctx,
       bswap32_if_le (idx),
     };
 
-  _slh_shake (ctx->pub.tree_ctx, &ah, ctx->secret_seed, sk);
+  ctx->pub.ctx.hash->secret (ctx->pub.ctx.tree, &ah, ctx->secret_seed, sk);
 
   ah.type = bswap32_if_le (SLH_FORS_TREE);
-  _slh_shake (ctx->pub.tree_ctx, &ah, sk, leaf);
+  ctx->pub.ctx.hash->secret (ctx->pub.ctx.tree, &ah, ctx->secret_seed, leaf);
 }
 
 static void
@@ -69,7 +70,6 @@ static void
 fors_node (const struct slh_merkle_ctx_public *ctx, unsigned height, unsigned index,
           const uint8_t *left, const uint8_t *right, uint8_t *out)
 {
-  struct sha3_ctx sha3 = *ctx->tree_ctx;
   struct slh_address_hash ah =
     {
       bswap32_if_le (SLH_FORS_TREE),
@@ -77,15 +77,12 @@ fors_node (const struct slh_merkle_ctx_public *ctx, unsigned height, unsigned in
       bswap32_if_le (height),
       bswap32_if_le (index),
     };
-  sha3_256_update (&sha3, sizeof (ah), (const uint8_t *) &ah);
-  sha3_256_update (&sha3, _SLH_DSA_128_SIZE, left);
-  sha3_256_update (&sha3, _SLH_DSA_128_SIZE, right);
-  sha3_256_shake (&sha3, _SLH_DSA_128_SIZE, out);
+  ctx->ctx.hash->node (ctx->ctx.tree, &ah, left, right, out);
 }
 
 static void
 fors_sign_one (const struct slh_merkle_ctx_secret *ctx, unsigned a,
-              unsigned idx, uint8_t *signature, struct sha3_ctx *pub)
+              unsigned idx, uint8_t *signature)
 {
   uint8_t hash[_SLH_DSA_128_SIZE];
 
@@ -95,7 +92,7 @@ fors_sign_one (const struct slh_merkle_ctx_secret *ctx, unsigned a,
                signature + _SLH_DSA_128_SIZE);
   _merkle_verify (&ctx->pub, fors_node, a, idx, signature + _SLH_DSA_128_SIZE, hash);
 
-  sha3_256_update (pub, _SLH_DSA_128_SIZE, hash);
+  ctx->pub.ctx.hash->update (ctx->pub.ctx.scratch, _SLH_DSA_128_SIZE, hash);
 }
 
 void
@@ -109,11 +106,10 @@ _fors_sign (const struct slh_merkle_ctx_secret *ctx,
       bswap32_if_le (ctx->pub.keypair),
       0, 0,
     };
-  struct sha3_ctx sha3 = *ctx->pub.tree_ctx;
   unsigned i, w, bits;
   unsigned mask = (1 << fors->a) - 1;
 
-  sha3_256_update (&sha3, sizeof (ah), (const uint8_t *) &ah);
+  ctx->pub.ctx.hash->start (ctx->pub.ctx.tree, ctx->pub.ctx.scratch, &ah);
 
   for (i = w = bits = 0; i < fors->k; i++, signature += (fors->a + 1) * _SLH_DSA_128_SIZE)
     {
@@ -121,15 +117,15 @@ _fors_sign (const struct slh_merkle_ctx_secret *ctx,
        w = (w << 8) | *msg++;
       bits -= fors->a;
 
-      fors_sign_one (ctx, fors->a, (i << fors->a) + ((w >> bits) & mask), signature, &sha3);
+      fors_sign_one (ctx, fors->a, (i << fors->a) + ((w >> bits) & mask), signature);
      }
 
-  sha3_256_shake (&sha3, _SLH_DSA_128_SIZE, pub);
+  ctx->pub.ctx.hash->digest (ctx->pub.ctx.scratch, pub);
 }
 
 static void
 fors_verify_one (const struct slh_merkle_ctx_public *ctx, unsigned a,
-                unsigned idx, const uint8_t *signature, struct sha3_ctx *pub)
+                unsigned idx, const uint8_t *signature)
 {
   uint8_t root[_SLH_DSA_128_SIZE];
   struct slh_address_hash ah =
@@ -140,10 +136,10 @@ fors_verify_one (const struct slh_merkle_ctx_public *ctx, unsigned a,
       bswap32_if_le (idx),
     };
 
-  _slh_shake (ctx->tree_ctx, &ah, signature, root);
+  ctx->ctx.hash->secret (ctx->ctx.tree, &ah, signature, root);
   _merkle_verify (ctx, fors_node, a, idx, signature + _SLH_DSA_128_SIZE, root);
 
-  sha3_256_update (pub, _SLH_DSA_128_SIZE, root);
+  ctx->ctx.hash->update (ctx->ctx.scratch, _SLH_DSA_128_SIZE, root);
 }
 
 void
@@ -151,17 +147,16 @@ _fors_verify (const struct slh_merkle_ctx_public *ctx,
              const struct slh_fors_params *fors,
              const uint8_t *msg, const uint8_t *signature, uint8_t *pub)
 {
-  struct sha3_ctx sha3 = *ctx->tree_ctx;
-  unsigned i, w, bits;
-  unsigned mask = (1 << fors->a) - 1;
   struct slh_address_hash ah =
     {
       bswap32_if_le (SLH_FORS_ROOTS),
       bswap32_if_le (ctx->keypair),
       0, 0,
     };
+  unsigned i, w, bits;
+  unsigned mask = (1 << fors->a) - 1;
 
-  sha3_256_update (&sha3, sizeof (ah), (const uint8_t *) &ah);
+  ctx->ctx.hash->start (ctx->ctx.tree, ctx->ctx.scratch, &ah);
 
   for (i = w = bits = 0; i < fors->k; i++, signature += (fors->a + 1) * _SLH_DSA_128_SIZE)
     {
@@ -169,7 +164,7 @@ _fors_verify (const struct slh_merkle_ctx_public *ctx,
        w = (w << 8) | *msg++;
       bits -= fors->a;
 
-      fors_verify_one (ctx, fors->a, (i << fors->a) + ((w >> bits) & mask), signature, &sha3);
+      fors_verify_one (ctx, fors->a, (i << fors->a) + ((w >> bits) & mask), signature);
     }
-  sha3_256_shake (&sha3, _SLH_DSA_128_SIZE, pub);
+  ctx->ctx.hash->digest (ctx->ctx.scratch, pub);
 }
diff --git a/slh-sha256.c b/slh-sha256.c
new file mode 100644 (file)
index 0000000..d19749c
--- /dev/null
@@ -0,0 +1,135 @@
+/* slh-sha256.c
+
+   Copyright (C) 2025 Niels Möller
+
+   This file is part of GNU Nettle.
+
+   GNU Nettle is free software: you can redistribute it and/or
+   modify it under the terms of either:
+
+     * the GNU Lesser General Public License as published by the Free
+       Software Foundation; either version 3 of the License, or (at your
+       option) any later version.
+
+   or
+
+     * the GNU General Public License as published by the Free
+       Software Foundation; either version 2 of the License, or (at your
+       option) any later version.
+
+   or both in parallel, as here.
+
+   GNU Nettle is distributed in the hope that it will be useful,
+   but WITHOUT ANY WARRANTY; without even the implied warranty of
+   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+   General Public License for more details.
+
+   You should have received copies of the GNU General Public License and
+   the GNU Lesser General Public License along with this program.  If
+   not, see http://www.gnu.org/licenses/.
+*/
+
+#if HAVE_CONFIG_H
+# include "config.h"
+#endif
+
+#include <string.h>
+
+#include "slh-dsa-internal.h"
+
+#include "bswap-internal.h"
+#include "hmac.h"
+#include "sha2.h"
+
+/* Uses a "compressed" address,
+
+     uint8_t layer
+     uint64_t tree_idx
+     uint8_t type
+
+   (packed, no padding).
+*/
+
+/* All hashing but H_{msg} and PRF_{msg} use plain sha256: */
+void
+_slh_sha256_init (struct sha256_ctx *ctx, const uint8_t *public_seed,
+                 uint32_t layer, uint64_t tree_idx)
+{
+  static const uint8_t pad[48];
+  uint8_t addr_layer;
+  uint64_t addr_tree;
+
+  sha256_init (ctx);
+  sha256_update (ctx, _SLH_DSA_128_SIZE, public_seed);
+  /* This padding completes a sha256 block. */
+  sha256_update (ctx, sizeof(pad), pad);
+  /* Compressed address. */
+  addr_layer = layer;
+  sha256_update (ctx, 1, &addr_layer);
+  addr_tree = bswap64_if_le (tree_idx);
+  sha256_update (ctx, sizeof(addr_tree), (const uint8_t *) &addr_tree);
+}
+
+void
+_slh_sha256 (const struct sha256_ctx *tree_ctx,
+            const struct slh_address_hash *ah,
+            const uint8_t *secret, uint8_t *out)
+{
+  struct sha256_ctx ctx = *tree_ctx;
+  uint8_t digest[SHA256_DIGEST_SIZE];
+  /* For compressed addr, hash only last byte of the type. */
+  sha256_update (&ctx, sizeof(*ah) - 3, (const uint8_t *) ah + 3);
+  sha256_update (&ctx, _SLH_DSA_128_SIZE, secret);
+  sha256_digest (&ctx, digest);
+  memcpy (out, digest, _SLH_DSA_128_SIZE);
+}
+
+static const uint8_t slh_pure_prefix[2] = {0, 0};
+
+void
+_slh_sha256_randomizer (const uint8_t *public_seed, const uint8_t *secret_prf,
+                      size_t msg_length, const uint8_t *msg,
+                      uint8_t *randomizer)
+{
+  struct hmac_sha256_ctx ctx;
+  uint8_t digest[SHA256_DIGEST_SIZE];
+  hmac_sha256_set_key (&ctx, _SLH_DSA_128_SIZE, secret_prf);
+  hmac_sha256_update (&ctx, _SLH_DSA_128_SIZE, public_seed);
+  hmac_sha256_update (&ctx, sizeof (slh_pure_prefix), slh_pure_prefix);
+  hmac_sha256_update (&ctx, msg_length, msg);
+  hmac_sha256_digest (&ctx, digest);
+  memcpy (randomizer, digest, _SLH_DSA_128_SIZE);
+}
+
+void
+_slh_sha256_msg_digest (const uint8_t *randomizer, const uint8_t *pub,
+                       size_t length, const uint8_t *msg,
+                       size_t digest_size, uint8_t *digest)
+{
+  struct sha256_ctx ctx;
+  uint8_t inner[SHA256_DIGEST_SIZE];
+  uint32_t i;
+  sha256_init (&ctx);
+  sha256_update (&ctx, _SLH_DSA_128_SIZE, randomizer);
+  sha256_update (&ctx, 2*_SLH_DSA_128_SIZE, pub);
+  sha256_update (&ctx, sizeof (slh_pure_prefix), slh_pure_prefix);
+  sha256_update (&ctx, length, msg);
+  sha256_digest (&ctx, inner);
+
+  /* mgf1 with inner digest as the seed. */
+  for (i = 0; digest_size > 0; i++)
+    {
+      uint32_t i_be = bswap32_if_le (i);
+      sha256_update (&ctx, sizeof (inner), inner);
+      sha256_update (&ctx, sizeof (i_be), (const uint8_t *) &i_be);
+      if (digest_size < SHA256_DIGEST_SIZE)
+       {
+         sha256_digest (&ctx, inner);
+         memcpy (digest, inner, digest_size);
+         break;
+       }
+      sha256_digest (&ctx, digest);
+      digest += SHA256_DIGEST_SIZE;
+      digest_size -= SHA256_DIGEST_SIZE;
+    }
+}
index bc2698ccaa9a997a0adb2297cf0805e891a2835c..8a97168a13e04e57de37a8020e7dfb0ce109b914 100644 (file)
@@ -57,12 +57,80 @@ _slh_shake_init (struct sha3_ctx *ctx, const uint8_t *public_seed,
   sha3_256_update (ctx, sizeof (at), (const uint8_t *) &at);
 }
 
+static void
+slh_shake_start (const struct sha3_ctx *tree_ctx, struct sha3_ctx *ctx,
+                const struct slh_address_hash *ah)
+{
+  *ctx = *tree_ctx;
+  sha3_256_update (ctx, sizeof (*ah), (const uint8_t *) ah);
+}
+
 void
 _slh_shake (const struct sha3_ctx *tree_ctx, const struct slh_address_hash *ah,
            const uint8_t *secret, uint8_t *out)
 {
-  struct sha3_ctx ctx = *tree_ctx;
-  sha3_256_update (&ctx, sizeof (*ah), (const uint8_t *) ah);
+  struct sha3_ctx ctx;
+  slh_shake_start (tree_ctx, &ctx, ah);;
   sha3_256_update (&ctx, _SLH_DSA_128_SIZE, secret);
   sha3_256_shake (&ctx, _SLH_DSA_128_SIZE, out);
 }
+
+static void
+slh_shake_node (const struct sha3_ctx *tree_ctx, const struct slh_address_hash *ah,
+                const uint8_t *left, const uint8_t *right, uint8_t *out)
+{
+  struct sha3_ctx ctx;
+  slh_shake_start (tree_ctx, &ctx, ah);;
+  sha3_256_update (&ctx, _SLH_DSA_128_SIZE, left);
+  sha3_256_update (&ctx, _SLH_DSA_128_SIZE, right);
+  sha3_256_shake (&ctx, _SLH_DSA_128_SIZE, out);
+}
+
+void
+_slh_shake_digest (struct sha3_ctx *ctx, uint8_t *out)
+{
+  sha3_256_shake (ctx, _SLH_DSA_128_SIZE, out);
+}
+
+static const uint8_t slh_pure_prefix[2] = {0, 0};
+
+void
+_slh_shake_randomizer (const uint8_t *public_seed, const uint8_t *secret_prf,
+                      size_t msg_length, const uint8_t *msg,
+                      uint8_t *randomizer)
+{
+  struct sha3_ctx ctx;
+
+  sha3_init (&ctx);
+  sha3_256_update (&ctx, _SLH_DSA_128_SIZE, secret_prf);
+  sha3_256_update (&ctx, _SLH_DSA_128_SIZE, public_seed);
+  sha3_256_update (&ctx, sizeof (slh_pure_prefix), slh_pure_prefix);
+  sha3_256_update (&ctx, msg_length, msg);
+  sha3_256_shake (&ctx, _SLH_DSA_128_SIZE, randomizer);
+}
+
+void
+_slh_shake_msg_digest (const uint8_t *randomizer, const uint8_t *pub,
+                      size_t length, const uint8_t *msg,
+                      size_t digest_size, uint8_t *digest)
+{
+  struct sha3_ctx ctx;
+
+  sha3_init (&ctx);
+  sha3_256_update (&ctx, _SLH_DSA_128_SIZE, randomizer);
+  sha3_256_update (&ctx, 2*_SLH_DSA_128_SIZE, pub);
+  sha3_256_update (&ctx, sizeof (slh_pure_prefix), slh_pure_prefix);
+  sha3_256_update (&ctx, length, msg);
+  sha3_256_shake (&ctx, digest_size, digest);
+}
+
+const struct slh_hash
+_slh_hash_shake =
+  {
+    (slh_hash_init_func *) _slh_shake_init,
+    (slh_hash_secret_func *) _slh_shake,
+    (slh_hash_node_func *) slh_shake_node,
+    (slh_hash_start_func *) slh_shake_start,
+    (nettle_hash_update_func *) sha3_256_update,
+    (nettle_hash_digest_func *)_slh_shake_digest
+  };
index ec14c5158457b31f769daa30446a7124e15b5a9b..bd591af98b8d58b4ed847f6e09032a78ca56c238 100644 (file)
@@ -44,7 +44,7 @@
    dst. For the ah argument, leaves ah->keypair and ah->height_chain
    unchanged, but overwrites the other fields. */
 static const uint8_t *
-wots_chain (const struct sha3_ctx *ctx,
+wots_chain (const struct slh_hash_ctxs *ctx,
            struct slh_address_hash *ah,
            unsigned i, unsigned s,
            const uint8_t *src, uint8_t *dst)
@@ -57,38 +57,36 @@ wots_chain (const struct sha3_ctx *ctx,
   ah->type = bswap32_if_le (SLH_WOTS_HASH);
   ah->index_hash = bswap32_if_le (i);
 
-  _slh_shake (ctx, ah, src, dst);
+  ctx->hash->secret(ctx->tree, ah, src, dst);
 
   for (j = 1; j < s; j++)
     {
       ah->index_hash = bswap32_if_le (i + j);
-      _slh_shake (ctx, ah, dst, dst);
+      ctx->hash->secret (ctx->tree, ah, dst, dst);
     }
 
   return dst;
 }
 
 static void
-wots_pk_init (const struct sha3_ctx *tree_ctx,
-             unsigned keypair, struct slh_address_hash *ah, struct sha3_ctx *ctx)
+wots_pk_init (const struct slh_hash_ctxs *ctx,
+             unsigned keypair, struct slh_address_hash *ah)
 {
   ah->type = bswap32_if_le (SLH_WOTS_PK);
   ah->keypair = bswap32_if_le (keypair);
   ah->height_chain = 0;
   ah->index_hash = 0;
-  *ctx = *tree_ctx;
-  sha3_256_update (ctx, sizeof (*ah), (const uint8_t *) ah);
+  ctx->hash->start(ctx->tree, ctx->scratch, ah);
 }
 
 void
-_wots_gen (const struct sha3_ctx *tree_ctx, const uint8_t *secret_seed,
+_wots_gen (const struct slh_hash_ctxs *ctx, const uint8_t *secret_seed,
           uint32_t keypair, uint8_t *pub)
 {
   struct slh_address_hash ah;
-  struct sha3_ctx ctx;
   unsigned i;
 
-  wots_pk_init (tree_ctx, keypair, &ah, &ctx);
+  wots_pk_init (ctx, keypair, &ah);
 
   for (i = 0; i < _WOTS_SIGNATURE_LENGTH; i++)
     {
@@ -98,21 +96,21 @@ _wots_gen (const struct sha3_ctx *tree_ctx, const uint8_t *secret_seed,
       ah.type = bswap32_if_le (SLH_WOTS_PRF);
       ah.height_chain = bswap32_if_le (i);
       ah.index_hash = 0;
-      _slh_shake (tree_ctx, &ah, secret_seed, out);
+      ctx->hash->secret (ctx->tree, &ah, secret_seed, out);
 
       /* Hash chain. */
-      wots_chain (tree_ctx, &ah, 0, 15, out, out);
+      wots_chain (ctx, &ah, 0, 15, out, out);
 
-      sha3_256_update (&ctx, _SLH_DSA_128_SIZE, out);
+      ctx->hash->update (ctx->scratch, _SLH_DSA_128_SIZE, out);
     }
-  sha3_256_shake (&ctx, _SLH_DSA_128_SIZE, pub);
+  ctx->hash->digest (ctx->scratch, pub);
 }
 
 /* Produces signature hash corresponding to the ith message nybble. Modifies addr. */
 static void
-wots_sign_one (const struct sha3_ctx *tree_ctx, const uint8_t *secret_seed,
+wots_sign_one (const struct slh_hash_ctxs *ctx, const uint8_t *secret_seed,
               uint32_t keypair,
-              unsigned i, uint8_t msg, uint8_t *sig, struct sha3_ctx *ctx)
+              unsigned i, uint8_t msg, uint8_t *sig)
 {
   struct slh_address_hash ah;
   uint8_t pub[_SLH_DSA_128_SIZE];
@@ -123,47 +121,46 @@ wots_sign_one (const struct sha3_ctx *tree_ctx, const uint8_t *secret_seed,
   ah.keypair = bswap32_if_le (keypair);
   ah.height_chain = bswap32_if_le (i);
   ah.index_hash = 0;
-  _slh_shake (tree_ctx, &ah, secret_seed, sig);
+  ctx->hash->secret(ctx->tree, &ah, secret_seed, sig);
 
   /* Hash chain. */
-  wots_chain (tree_ctx, &ah, 0, msg, sig, sig);
+  wots_chain (ctx, &ah, 0, msg, sig, sig);
 
-  sha3_256_update (ctx, _SLH_DSA_128_SIZE,
-                  wots_chain (tree_ctx, &ah, msg, 15 - msg, sig, pub));
+  ctx->hash->update (ctx->scratch, _SLH_DSA_128_SIZE,
+                    wots_chain (ctx, &ah, msg, 15 - msg, sig, pub));
 }
 
 void
-_wots_sign (const struct sha3_ctx *tree_ctx, const uint8_t *secret_seed,
+_wots_sign (const struct slh_hash_ctxs *ctx, const uint8_t *secret_seed,
            unsigned keypair, const uint8_t *msg, uint8_t *signature, uint8_t *pub)
 {
   struct slh_address_hash ah;
-  struct sha3_ctx ctx;
   unsigned i;
   uint32_t csum;
 
-  wots_pk_init (tree_ctx, keypair, &ah, &ctx);
+  wots_pk_init (ctx, keypair, &ah);
 
   for (i = 0, csum = 15*32; i < _SLH_DSA_128_SIZE; i++)
     {
       uint8_t m0, m1;
       m0 = msg[i] >> 4;
       csum -= m0;
-      wots_sign_one (tree_ctx, secret_seed, keypair, 2*i, m0, signature, &ctx);
+      wots_sign_one (ctx, secret_seed, keypair, 2*i, m0, signature);
 
       m1 = msg[i] & 0xf;
       csum -= m1;
-      wots_sign_one (tree_ctx, secret_seed, keypair, 2*i + 1, m1, signature, &ctx);
+      wots_sign_one (ctx, secret_seed, keypair, 2*i + 1, m1, signature);
     }
 
-  wots_sign_one (tree_ctx, secret_seed, keypair, 32, csum >> 8, signature, &ctx);
-  wots_sign_one (tree_ctx, secret_seed, keypair, 33, (csum >> 4) & 0xf, signature, &ctx);
-  wots_sign_one (tree_ctx, secret_seed, keypair, 34, csum & 0xf, signature, &ctx);
+  wots_sign_one (ctx, secret_seed, keypair, 32, csum >> 8, signature);
+  wots_sign_one (ctx, secret_seed, keypair, 33, (csum >> 4) & 0xf, signature);
+  wots_sign_one (ctx, secret_seed, keypair, 34, csum & 0xf, signature);
 
-  sha3_256_shake (&ctx, _SLH_DSA_128_SIZE, pub);
+  ctx->hash->digest (ctx->scratch, pub);
 }
 
 static void
-wots_verify_one (const struct sha3_ctx *tree_ctx, struct sha3_ctx *ctx,
+wots_verify_one (struct slh_hash_ctxs *ctx,
                 uint32_t keypair, unsigned i, uint8_t msg, const uint8_t *signature)
 {
   struct slh_address_hash ah;
@@ -173,36 +170,35 @@ wots_verify_one (const struct sha3_ctx *tree_ctx, struct sha3_ctx *ctx,
   ah.keypair = bswap32_if_le (keypair);
   ah.height_chain = bswap32_if_le (i);
 
-  sha3_256_update (ctx, _SLH_DSA_128_SIZE,
-                  wots_chain (tree_ctx, &ah, msg, 15 - msg, signature, out));
+  ctx->hash->update(ctx->scratch, _SLH_DSA_128_SIZE,
+                   wots_chain (ctx, &ah, msg, 15 - msg, signature, out));
 }
 
 void
-_wots_verify (const struct sha3_ctx *tree_ctx,
+_wots_verify (struct slh_hash_ctxs *ctx,
              unsigned keypair, const uint8_t *msg, const uint8_t *signature, uint8_t *pub)
 {
   struct slh_address_hash ah;
-  struct sha3_ctx ctx;
   unsigned i;
   uint32_t csum;
 
-  wots_pk_init (tree_ctx, keypair, &ah, &ctx);
+  wots_pk_init (ctx, keypair, &ah);
 
   for (i = 0, csum = 15*32; i < _SLH_DSA_128_SIZE; i++)
     {
       uint8_t m0, m1;
       m0 = msg[i] >> 4;
       csum -= m0;
-      wots_verify_one (tree_ctx, &ctx, keypair, 2*i, m0, signature);
+      wots_verify_one (ctx, keypair, 2*i, m0, signature);
 
       m1 = msg[i] & 0xf;
       csum -= m1;
-      wots_verify_one (tree_ctx, &ctx, keypair, 2*i + 1, m1, signature);
+      wots_verify_one (ctx, keypair, 2*i + 1, m1, signature);
     }
 
-  wots_verify_one (tree_ctx, &ctx, keypair, 32, csum >> 8, signature);
-  wots_verify_one (tree_ctx, &ctx, keypair, 33, (csum >> 4) & 0xf, signature);
-  wots_verify_one (tree_ctx, &ctx, keypair, 34, csum & 0xf, signature);
+  wots_verify_one (ctx, keypair, 32, csum >> 8, signature);
+  wots_verify_one (ctx, keypair, 33, (csum >> 4) & 0xf, signature);
+  wots_verify_one (ctx, keypair, 34, csum & 0xf, signature);
 
-  sha3_256_shake (&ctx, _SLH_DSA_128_SIZE, pub);
+  ctx->hash->digest (ctx->scratch, pub);
 }
index dea360a8b41ab7c75e26a6ec7124a850f0eaf37b..b8f33d1031bba474d14fb88759bbf3b43e9bc5b6 100644 (file)
 static void
 xmss_leaf (const struct slh_merkle_ctx_secret *ctx, unsigned idx, uint8_t *leaf)
 {
-  _wots_gen (ctx->pub.tree_ctx, ctx->secret_seed, idx, leaf);
+  _wots_gen (&ctx->pub.ctx, ctx->secret_seed, idx, leaf);
 }
 
 static void
 xmss_node (const struct slh_merkle_ctx_public *ctx, unsigned height, unsigned index,
           const uint8_t *left, const uint8_t *right, uint8_t *out)
 {
-  struct sha3_ctx sha3 = *ctx->tree_ctx;
   struct slh_address_hash ah =
     {
       bswap32_if_le (SLH_XMSS_TREE),
@@ -57,33 +56,32 @@ xmss_node (const struct slh_merkle_ctx_public *ctx, unsigned height, unsigned in
       bswap32_if_le (height),
       bswap32_if_le (index),
     };
-
-  sha3_256_update (&sha3, sizeof (ah), (const uint8_t *) &ah);
-  sha3_256_update (&sha3, _SLH_DSA_128_SIZE, left);
-  sha3_256_update (&sha3, _SLH_DSA_128_SIZE, right);
-  sha3_256_shake (&sha3, _SLH_DSA_128_SIZE, out);
+  ctx->ctx.hash->node (ctx->ctx.tree, &ah, left, right, out);
 }
 
 void
-_xmss_gen (const uint8_t *public_seed, const uint8_t *secret_seed,
+_xmss_gen (const struct slh_hash *hash,
+          void *ha, void *hb,
+          const uint8_t *public_seed, const uint8_t *secret_seed,
           const struct slh_xmss_params *xmss,
           uint8_t *scratch, uint8_t *root)
 {
-  struct sha3_ctx tree_ctx;
   struct slh_merkle_ctx_secret ctx =
     {
-      { &tree_ctx, 0 },
+      { { hash, ha, hb }, 0 },
       secret_seed
     };
-  _slh_shake_init (&tree_ctx, public_seed, xmss->d - 1, 0);
+  hash->init(ha, public_seed, xmss->d - 1, 0);
   _merkle_root (&ctx, xmss_leaf, xmss_node, xmss->h, 0, root, scratch);
 }
 
+/* TODO: Like _wots_sign, pass hash context and secret_seed as
+   separate arguments. */
 void
 _xmss_sign (const struct slh_merkle_ctx_secret *ctx, unsigned h,
            unsigned idx, const uint8_t *msg, uint8_t *signature, uint8_t *pub)
 {
-  _wots_sign (ctx->pub.tree_ctx, ctx->secret_seed, idx, msg, signature, pub);
+  _wots_sign (&ctx->pub.ctx, ctx->secret_seed, idx, msg, signature, pub);
   signature += WOTS_SIGNATURE_SIZE;
 
   _merkle_sign (ctx, xmss_leaf, xmss_node, h, idx, signature);
@@ -94,7 +92,7 @@ void
 _xmss_verify (const struct slh_merkle_ctx_public *ctx, unsigned h,
              unsigned idx, const uint8_t *msg, const uint8_t *signature, uint8_t *pub)
 {
-  _wots_verify (ctx->tree_ctx, idx, msg, signature, pub);
+  _wots_verify (&ctx->ctx, idx, msg, signature, pub);
   signature += WOTS_SIGNATURE_SIZE;
 
   _merkle_verify (ctx, xmss_node, h, idx, signature, pub);