]> git.ipfire.org Git - thirdparty/nettle.git/commitdiff
Refactor slh hash abstraction; use union for the context.
authorNiels Möller <nisse@lysator.liu.se>
Wed, 2 Jul 2025 17:55:27 +0000 (19:55 +0200)
committerNiels Möller <nisse@lysator.liu.se>
Sat, 5 Jul 2025 08:14:55 +0000 (10:14 +0200)
slh-dsa-internal.h
slh-dsa-shake-128f.c
slh-dsa-shake-128s.c
slh-dsa.c
slh-fors.c
slh-sha256.c
slh-shake.c
slh-wots.c
slh-xmss.c
testsuite/slh-dsa-test.c

index 1fd77c084c5e50df7f6fa164fdd92d22e9c893db..a6c066aa3108b89446514ed332c853726728b4c4 100644 (file)
 #include <stdint.h>
 
 #include "nettle-types.h"
+#include "sha2.h"
+#include "sha3.h"
 
 /* Name mangling */
-#define _slh_shake_init _nettle_slh_shake_init
-#define _slh_shake _nettle_slh_shake
-#define _slh_shake_digest _nettle_slh_shake_digest
 #define _slh_shake_randomizer _nettle_slh_shake_randomizer
 #define _slh_shake_msg_digest _nettle_slh_shake_msg_digest
-#define _slh_sha256_init _nettle_slh_sha256_init
-#define _slh_sha256 _nettle_slh_sha256
 #define _slh_sha256_randomizer _nettle_slh_sha256_randomizer
-#define _slh_sha256_digest _nettle_slh_sha256_digest
+#define _slh_sha256_msg_digest _nettle_slh_sha256_msg_digest
 #define _wots_gen _nettle_wots_gen
 #define _wots_sign _nettle_wots_sign
 #define _wots_verify _nettle_wots_verify
@@ -64,6 +61,9 @@
 #define _slh_dsa_shake_128s_params _nettle_slh_dsa_shake_128s_params
 #define _slh_dsa_shake_128f_params _nettle_slh_dsa_shake_128f_params
 
+#define _slh_hash_shake _nettle_slh_hash_shake
+#define _slh_hash_sha256 _nettle_slh_hash_sha256
+
 /* Size of a single hash, including the seed and prf parameters */
 #define _SLH_DSA_128_SIZE 16
 
@@ -89,40 +89,42 @@ enum slh_addr_type
     SLH_FORS_PRF = 6,
   };
 
-typedef void slh_hash_init_func (void *tree_ctx, const uint8_t *public_seed,
-                                uint32_t layer, uint64_t tree_idx);
-typedef void slh_hash_secret_func (const void *tree_ctx,
+union slh_hash_ctx
+{
+  struct sha256_ctx sha256;
+  struct sha3_ctx sha3;
+};
+
+typedef void slh_hash_init_tree_func (union slh_hash_ctx *tree_ctx, const uint8_t *public_seed,
+                                     uint32_t layer, uint64_t tree_idx);
+typedef void slh_hash_init_hash_func (const union slh_hash_ctx *tree_ctx, union slh_hash_ctx *ctx,
+                                     const struct slh_address_hash *ah);
+typedef void slh_hash_secret_func (const union slh_hash_ctx *tree_ctx,
                                   const struct slh_address_hash *ah,
                                   const uint8_t *secret, uint8_t *out);
-typedef void slh_hash_node_func (const void *tree_ctx,
+typedef void slh_hash_node_func (const union slh_hash_ctx *tree_ctx,
                                 const struct slh_address_hash *ah,
                                 const uint8_t *left, const uint8_t *right,
                                 uint8_t *out);
-typedef void slh_hash_start_func (const void *tree_ctx, void *ctx, const struct slh_address_hash *ah);
 
 struct slh_hash
 {
-  slh_hash_init_func *init;
-  slh_hash_secret_func *secret;
-  slh_hash_node_func *node;
-  slh_hash_start_func *start;
+  slh_hash_init_tree_func *init_tree;
+  slh_hash_init_hash_func *init_hash;
   nettle_hash_update_func *update;
   nettle_hash_digest_func *digest;
+  slh_hash_secret_func *secret;
+  slh_hash_node_func *node;
 };
 
 extern const struct slh_hash _slh_hash_shake;
-struct slh_hash_ctxs
-{
-  const struct slh_hash *hash;
-  /* Initialized based on public seed and slh_address_tree. */
-  const void *tree;
-  /* Working ctx for wots and fors. */
-  void *scratch;
-};
+extern const struct slh_hash _slh_hash_sha256;
 
 struct slh_merkle_ctx_public
 {
-  struct slh_hash_ctxs ctx;
+  const struct slh_hash *hash;
+  /* Initialized using hash->init_tree. */
+  union slh_hash_ctx tree_ctx;
   unsigned keypair; /* Used only by fors_leaf and fors_node. */
 };
 
@@ -143,6 +145,7 @@ struct slh_fors_params
 {
   unsigned short a; /* Height of tree. */
   unsigned short k; /* Number of trees. */
+  unsigned short msg_size; /* Currently used only by tests. */
   unsigned short signature_size;
 };
 
@@ -152,20 +155,8 @@ struct slh_dsa_params
   struct slh_fors_params fors;
 };
 
-extern const struct slh_dsa_params _slh_dsa_128s_params;
-
-struct sha3_ctx;
-void
-_slh_shake_init (struct sha3_ctx *ctx, const uint8_t *public_seed,
-                uint32_t layer, uint64_t tree_idx);
-
-void
-_slh_shake (const struct sha3_ctx *tree_ctx,
-           const struct slh_address_hash *ah,
-           const uint8_t *secret, uint8_t *out);
-
-void
-_slh_shake_digest (struct sha3_ctx *ctx, uint8_t *out);
+extern const struct slh_dsa_params _slh_dsa_shake_128s_params;
+extern const struct slh_dsa_params _slh_dsa_shake_128f_params;
 
 void
 _slh_shake_randomizer (const uint8_t *public_seed, const uint8_t *secret_prf,
@@ -176,16 +167,6 @@ _slh_shake_msg_digest (const uint8_t *randomizer, const uint8_t *pub,
                       size_t length, const uint8_t *msg,
                       size_t digest_size, uint8_t *digest);
 
-struct sha256_ctx;
-void
-_slh_sha256_init (struct sha256_ctx *ctx, const uint8_t *public_seed,
-                uint32_t layer, uint64_t tree_idx);
-
-void
-_slh_sha256 (const struct sha256_ctx *tree_ctx,
-            const struct slh_address_hash *ah,
-            const uint8_t *secret, uint8_t *out);
-
 void
 _slh_sha256_randomizer (const uint8_t *public_seed, const uint8_t *secret_prf,
                        size_t msg_length, const uint8_t *msg,
@@ -200,21 +181,23 @@ _slh_sha256_msg_digest (const uint8_t *randomizer, const uint8_t *pub,
 #define WOTS_SIGNATURE_SIZE (_WOTS_SIGNATURE_LENGTH*_SLH_DSA_128_SIZE)
 
 void
-_wots_gen (const struct slh_hash_ctxs *ctx, const uint8_t *secret_seed,
+_wots_gen (const struct slh_hash *hash, const union slh_hash_ctx *tree_ctx,
+          const uint8_t *secret_seed,
           uint32_t keypair, uint8_t *pub);
 
 void
-_wots_sign (const struct slh_hash_ctxs *ctx, const uint8_t *secret_seed,
-           unsigned keypair, const uint8_t *msg, uint8_t *signature, uint8_t *pub);
+_wots_sign (const struct slh_hash *hash, const union slh_hash_ctx *tree_ctx,
+           const uint8_t *secret_seed,
+           unsigned keypair, const uint8_t *msg,
+           uint8_t *signature, uint8_t *pub);
 
 /* Computes candidate public key from signature. */
 void
-_wots_verify (struct slh_hash_ctxs *ctx,
+_wots_verify (const struct slh_hash *hash, const union slh_hash_ctx *tree_ctx,
              unsigned keypair, const uint8_t *msg, const uint8_t *signature, uint8_t *pub);
 
-/* Merkle tree functions. Leaf function uses a non-const context, to allow the ctx to point at
-   working storage. Could be generalized for other merkle tree
-   applications, by using void * for the ctx argument. */
+/* Merkle tree functions. Could be generalized for other merkle tree
+   applications, by using const void* for the ctx argument. */
 typedef void merkle_leaf_hash_func (const struct slh_merkle_ctx_secret *ctx, unsigned index, uint8_t *out);
 typedef void merkle_node_hash_func (const struct slh_merkle_ctx_public *ctx, unsigned height, unsigned index,
                                    const uint8_t *left, const uint8_t *right, uint8_t *out);
@@ -261,7 +244,6 @@ _fors_verify (const struct slh_merkle_ctx_public *ctx,
 /* Provided scratch must be of size (xmss->h + 1) * _SLH_DSA_128_SIZE. */
 void
 _xmss_gen (const struct slh_hash *hash,
-          void *ha, void *hb,
           const uint8_t *public_seed, const uint8_t *secret_seed,
           const struct slh_xmss_params *xmss,
           uint8_t *scratch, uint8_t *root);
@@ -279,7 +261,6 @@ _xmss_verify (const struct slh_merkle_ctx_public *ctx, unsigned h,
 void
 _slh_dsa_sign (const struct slh_dsa_params *params,
               const struct slh_hash *hash,
-              void *ha, void *hb,
               const uint8_t *pub, const uint8_t *priv,
               const uint8_t *digest,
               uint64_t tree_idx, unsigned leaf_idx,
@@ -287,7 +268,6 @@ _slh_dsa_sign (const struct slh_dsa_params *params,
 int
 _slh_dsa_verify (const struct slh_dsa_params *params,
                 const struct slh_hash *hash,
-                void *ha, void *hb,
                 const uint8_t *pub,
                 const uint8_t *digest, uint64_t tree_idx, unsigned leaf_idx,
                 const uint8_t *signature);
index 3d7da32e976538fe8045012b7e902e4f067c70fc..a3287675524f7023245372745a20190777bc27f8 100644 (file)
@@ -38,8 +38,6 @@
 #include "slh-dsa.h"
 #include "slh-dsa-internal.h"
 
-#include "sha3.h"
-
 #define SLH_DSA_M 34
 
 #define SLH_DSA_D 22
@@ -55,7 +53,7 @@ const struct slh_dsa_params
 _slh_dsa_shake_128f_params =
   {
     { SLH_DSA_D, XMSS_H, XMSS_SIGNATURE_SIZE (XMSS_H) },
-    { FORS_A, FORS_K, FORS_SIGNATURE_SIZE (FORS_A, FORS_K) },
+    { FORS_A, FORS_K, FORS_MSG_SIZE, FORS_SIGNATURE_SIZE (FORS_A, FORS_K) },
   };
 
 void
@@ -63,8 +61,7 @@ slh_dsa_shake_128f_root (const uint8_t *public_seed, const uint8_t *private_seed
                         uint8_t *root)
 {
   uint8_t scratch[(XMSS_H + 1)*_SLH_DSA_128_SIZE];
-  struct sha3_ctx ha, hb;
-  _xmss_gen (&_slh_hash_shake, &ha, &hb, public_seed, private_seed,
+  _xmss_gen (&_slh_hash_shake, public_seed, private_seed,
             &_slh_dsa_shake_128f_params.xmss, scratch, root);
 }
 
@@ -113,14 +110,12 @@ slh_dsa_shake_128f_sign (const uint8_t *pub, const uint8_t *priv,
   uint8_t digest[SLH_DSA_M];
   uint64_t tree_idx;
   unsigned leaf_idx;
-  struct sha3_ctx ha, hb;
 
   _slh_shake_randomizer (pub, priv + _SLH_DSA_128_SIZE, length, msg, signature);
   _slh_shake_msg_digest (signature, pub, length, msg, SLH_DSA_M, digest);
   parse_digest (digest + FORS_MSG_SIZE, &tree_idx, &leaf_idx);
 
-  _slh_dsa_sign (&_slh_dsa_shake_128f_params,
-                &_slh_hash_shake, &ha, &hb,
+  _slh_dsa_sign (&_slh_dsa_shake_128f_params, &_slh_hash_shake,
                 pub, priv, digest, tree_idx, leaf_idx,
                 signature + _SLH_DSA_128_SIZE);
 }
@@ -133,12 +128,10 @@ slh_dsa_shake_128f_verify (const uint8_t *pub,
   uint8_t digest[SLH_DSA_M];
   uint64_t tree_idx;
   unsigned leaf_idx;
-  struct sha3_ctx ha, hb;
 
   _slh_shake_msg_digest (signature, pub, length, msg, SLH_DSA_M,digest);
   parse_digest (digest + FORS_MSG_SIZE, &tree_idx, &leaf_idx);
-  return _slh_dsa_verify (&_slh_dsa_shake_128f_params,
-                         &_slh_hash_shake, &ha, &hb,
+  return _slh_dsa_verify (&_slh_dsa_shake_128f_params, &_slh_hash_shake,
                          pub, digest, tree_idx, leaf_idx,
                          signature + _SLH_DSA_128_SIZE);
 }
index f34c3bd26c1e06e5a2ffcd1dbef802924d1bfd07..6a34c55f168038066d629173071ecb3c8726f6f4 100644 (file)
@@ -38,8 +38,6 @@
 #include "slh-dsa.h"
 #include "slh-dsa-internal.h"
 
-#include "sha3.h"
-
 #define SLH_DSA_M 30
 
 #define SLH_DSA_D 7
@@ -55,7 +53,7 @@ const struct slh_dsa_params
 _slh_dsa_shake_128s_params =
   {
     { SLH_DSA_D, XMSS_H, XMSS_SIGNATURE_SIZE (XMSS_H) },
-    { FORS_A, FORS_K, FORS_SIGNATURE_SIZE (FORS_A, FORS_K) },
+    { FORS_A, FORS_K, FORS_MSG_SIZE, FORS_SIGNATURE_SIZE (FORS_A, FORS_K) },
   };
 
 void
@@ -63,8 +61,7 @@ slh_dsa_shake_128s_root (const uint8_t *public_seed, const uint8_t *private_seed
                         uint8_t *root)
 {
   uint8_t scratch[(XMSS_H + 1)*_SLH_DSA_128_SIZE];
-  struct sha3_ctx ha, hb;
-  _xmss_gen (&_slh_hash_shake, &ha, &hb, public_seed, private_seed,
+  _xmss_gen (&_slh_hash_shake, public_seed, private_seed,
             &_slh_dsa_shake_128s_params.xmss, scratch, root);
 }
 
@@ -113,14 +110,12 @@ slh_dsa_shake_128s_sign (const uint8_t *pub, const uint8_t *priv,
   uint8_t digest[SLH_DSA_M];
   uint64_t tree_idx;
   unsigned leaf_idx;
-  struct sha3_ctx ha, hb;
 
   _slh_shake_randomizer (pub, priv + _SLH_DSA_128_SIZE, length, msg, signature);
   _slh_shake_msg_digest (signature, pub, length, msg, SLH_DSA_M, digest);
   parse_digest (digest + FORS_MSG_SIZE, &tree_idx, &leaf_idx);
 
-  _slh_dsa_sign (&_slh_dsa_shake_128s_params,
-                &_slh_hash_shake, &ha, &hb,
+  _slh_dsa_sign (&_slh_dsa_shake_128s_params, &_slh_hash_shake,
                 pub, priv, digest, tree_idx, leaf_idx,
                 signature + _SLH_DSA_128_SIZE);
 }
@@ -133,12 +128,10 @@ slh_dsa_shake_128s_verify (const uint8_t *pub,
   uint8_t digest[SLH_DSA_M];
   uint64_t tree_idx;
   unsigned leaf_idx;
-  struct sha3_ctx ha, hb;
 
   _slh_shake_msg_digest (signature, pub, length, msg, SLH_DSA_M,digest);
   parse_digest (digest + FORS_MSG_SIZE, &tree_idx, &leaf_idx);
-  return _slh_dsa_verify (&_slh_dsa_shake_128s_params,
-                         &_slh_hash_shake, &ha, &hb,
+  return _slh_dsa_verify (&_slh_dsa_shake_128s_params, &_slh_hash_shake,
                          pub, digest, tree_idx, leaf_idx,
                          signature + _SLH_DSA_128_SIZE);
 }
index 312b7aa050e0f1512238bad791e9a0b4f4c9bdbb..84f7f5e3d5faa3b8ec078f8f7200d2d87bda28cd 100644 (file)
--- a/slh-dsa.c
+++ b/slh-dsa.c
 #include "slh-dsa.h"
 #include "slh-dsa-internal.h"
 
-#if 0
-/* For 128s flavor. */
-#define SLH_DSA_M 30
-
-#define SLH_DSA_D 7
-#define XMSS_H 9
-
-/* Use k Merkle trees, each of size 2^a. Signs messages of size
-   k * a = 168 bits or 21 octets. */
-#define FORS_A 12
-#define FORS_K 14
-
-const struct slh_dsa_params
-_slh_dsa_128s_params =
-  {
-    { SLH_DSA_D, XMSS_H, XMSS_SIGNATURE_SIZE (XMSS_H) },
-    { FORS_A, FORS_K, FORS_SIGNATURE_SIZE (FORS_A, FORS_K) },
-  };
-#endif
-
 void
 _slh_dsa_sign (const struct slh_dsa_params *params,
               const struct slh_hash *hash,
-              void *ha, void *hb,
               const uint8_t *pub, const uint8_t *priv,
               const uint8_t *digest,
               uint64_t tree_idx, unsigned leaf_idx,
@@ -76,10 +55,10 @@ _slh_dsa_sign (const struct slh_dsa_params *params,
 
   struct slh_merkle_ctx_secret merkle_ctx =
     {
-      { { hash, ha, hb }, leaf_idx },
+      { hash, {}, leaf_idx },
       priv,
     };
-  hash->init(ha, pub, 0, tree_idx);
+  hash->init_tree (&merkle_ctx.pub.tree_ctx, pub, 0, tree_idx);
 
   _fors_sign (&merkle_ctx, &params->fors, digest, signature, root);
   signature += params->fors.signature_size;
@@ -93,7 +72,7 @@ _slh_dsa_sign (const struct slh_dsa_params *params,
       leaf_idx = tree_idx & ((1 << params->xmss.h) - 1);
       tree_idx >>= params->xmss.h;
 
-      hash->init(ha, pub, i, tree_idx);
+      hash->init_tree (&merkle_ctx.pub.tree_ctx, pub, i, tree_idx);
 
       _xmss_sign (&merkle_ctx, params->xmss.h, leaf_idx, root, signature, root);
     }
@@ -103,7 +82,6 @@ _slh_dsa_sign (const struct slh_dsa_params *params,
 int
 _slh_dsa_verify (const struct slh_dsa_params *params,
                 const struct slh_hash *hash,
-                void *ha, void *hb,
                 const uint8_t *pub,
                 const uint8_t *digest, uint64_t tree_idx, unsigned leaf_idx,
                 const uint8_t *signature)
@@ -112,9 +90,9 @@ _slh_dsa_verify (const struct slh_dsa_params *params,
   int i;
 
   struct slh_merkle_ctx_public merkle_ctx =
-    { { hash, ha, hb }, leaf_idx };
+    { hash, {}, leaf_idx };
 
-  hash->init(ha, pub, 0, tree_idx);
+  hash->init_tree (&merkle_ctx.tree_ctx, pub, 0, tree_idx);
 
   _fors_verify (&merkle_ctx, &params->fors, digest, signature, root);
   signature += params->fors.signature_size;
@@ -128,7 +106,7 @@ _slh_dsa_verify (const struct slh_dsa_params *params,
       leaf_idx = tree_idx & ((1 << params->xmss.h) - 1);
       tree_idx >>= params->xmss.h;
 
-      hash->init(ha, pub, i, tree_idx);
+      hash->init_tree (&merkle_ctx.tree_ctx, pub, i, tree_idx);
 
       _xmss_verify (&merkle_ctx, params->xmss.h, leaf_idx, root, signature, root);
     }
index 28eff50cff1b00b74b60f48dd228e332346139c0..2b03b7f4a09b97cc8c96514843d239ae91808e68 100644 (file)
@@ -41,7 +41,6 @@
 #include "sha3.h"
 #include "slh-dsa-internal.h"
 
-/* TODO: Like wots_gen, take hash ctx, secret seed (+ keypair) as separate arguments? */
 void
 _fors_gen (const struct slh_merkle_ctx_secret *ctx,
           unsigned idx, uint8_t *sk, uint8_t *leaf)
@@ -54,10 +53,10 @@ _fors_gen (const struct slh_merkle_ctx_secret *ctx,
       bswap32_if_le (idx),
     };
 
-  ctx->pub.ctx.hash->secret (ctx->pub.ctx.tree, &ah, ctx->secret_seed, sk);
+  ctx->pub.hash->secret (&ctx->pub.tree_ctx, &ah, ctx->secret_seed, sk);
 
   ah.type = bswap32_if_le (SLH_FORS_TREE);
-  ctx->pub.ctx.hash->secret (ctx->pub.ctx.tree, &ah, ctx->secret_seed, leaf);
+  ctx->pub.hash->secret (&ctx->pub.tree_ctx, &ah, sk, leaf);
 }
 
 static void
@@ -77,22 +76,22 @@ fors_node (const struct slh_merkle_ctx_public *ctx, unsigned height, unsigned in
       bswap32_if_le (height),
       bswap32_if_le (index),
     };
-  ctx->ctx.hash->node (ctx->ctx.tree, &ah, left, right, out);
+  ctx->hash->node (&ctx->tree_ctx, &ah, left, right, out);
 }
 
 static void
 fors_sign_one (const struct slh_merkle_ctx_secret *ctx, unsigned a,
-              unsigned idx, uint8_t *signature)
+              unsigned idx, uint8_t *signature, union slh_hash_ctx *pub)
 {
   uint8_t hash[_SLH_DSA_128_SIZE];
 
   _fors_gen (ctx, idx, signature, hash);
+  signature += _SLH_DSA_128_SIZE;
 
-  _merkle_sign (ctx, fors_leaf, fors_node, a, idx,
-               signature + _SLH_DSA_128_SIZE);
-  _merkle_verify (&ctx->pub, fors_node, a, idx, signature + _SLH_DSA_128_SIZE, hash);
+  _merkle_sign (ctx, fors_leaf, fors_node, a, idx, signature);
+  _merkle_verify (&ctx->pub, fors_node, a, idx, signature, hash);
 
-  ctx->pub.ctx.hash->update (ctx->pub.ctx.scratch, _SLH_DSA_128_SIZE, hash);
+  ctx->pub.hash->update (pub, _SLH_DSA_128_SIZE, hash);
 }
 
 void
@@ -106,10 +105,11 @@ _fors_sign (const struct slh_merkle_ctx_secret *ctx,
       bswap32_if_le (ctx->pub.keypair),
       0, 0,
     };
+  union slh_hash_ctx pub_ctx;
   unsigned i, w, bits;
   unsigned mask = (1 << fors->a) - 1;
 
-  ctx->pub.ctx.hash->start (ctx->pub.ctx.tree, ctx->pub.ctx.scratch, &ah);
+  ctx->pub.hash->init_hash (&ctx->pub.tree_ctx, &pub_ctx, &ah);
 
   for (i = w = bits = 0; i < fors->k; i++, signature += (fors->a + 1) * _SLH_DSA_128_SIZE)
     {
@@ -117,15 +117,15 @@ _fors_sign (const struct slh_merkle_ctx_secret *ctx,
        w = (w << 8) | *msg++;
       bits -= fors->a;
 
-      fors_sign_one (ctx, fors->a, (i << fors->a) + ((w >> bits) & mask), signature);
+      fors_sign_one (ctx, fors->a, (i << fors->a) + ((w >> bits) & mask), signature, &pub_ctx);
      }
 
-  ctx->pub.ctx.hash->digest (ctx->pub.ctx.scratch, pub);
+  ctx->pub.hash->digest (&pub_ctx, pub);
 }
 
 static void
 fors_verify_one (const struct slh_merkle_ctx_public *ctx, unsigned a,
-                unsigned idx, const uint8_t *signature)
+                unsigned idx, const uint8_t *signature, union slh_hash_ctx *pub)
 {
   uint8_t root[_SLH_DSA_128_SIZE];
   struct slh_address_hash ah =
@@ -136,10 +136,10 @@ fors_verify_one (const struct slh_merkle_ctx_public *ctx, unsigned a,
       bswap32_if_le (idx),
     };
 
-  ctx->ctx.hash->secret (ctx->ctx.tree, &ah, signature, root);
+  ctx->hash->secret (&ctx->tree_ctx, &ah, signature, root);
   _merkle_verify (ctx, fors_node, a, idx, signature + _SLH_DSA_128_SIZE, root);
 
-  ctx->ctx.hash->update (ctx->ctx.scratch, _SLH_DSA_128_SIZE, root);
+  ctx->hash->update (pub, _SLH_DSA_128_SIZE, root);
 }
 
 void
@@ -153,10 +153,11 @@ _fors_verify (const struct slh_merkle_ctx_public *ctx,
       bswap32_if_le (ctx->keypair),
       0, 0,
     };
+  union slh_hash_ctx pub_ctx;
   unsigned i, w, bits;
   unsigned mask = (1 << fors->a) - 1;
 
-  ctx->ctx.hash->start (ctx->ctx.tree, ctx->ctx.scratch, &ah);
+  ctx->hash->init_hash (&ctx->tree_ctx, &pub_ctx, &ah);
 
   for (i = w = bits = 0; i < fors->k; i++, signature += (fors->a + 1) * _SLH_DSA_128_SIZE)
     {
@@ -164,7 +165,7 @@ _fors_verify (const struct slh_merkle_ctx_public *ctx,
        w = (w << 8) | *msg++;
       bits -= fors->a;
 
-      fors_verify_one (ctx, fors->a, (i << fors->a) + ((w >> bits) & mask), signature);
+      fors_verify_one (ctx, fors->a, (i << fors->a) + ((w >> bits) & mask), signature, &pub_ctx);
     }
-  ctx->ctx.hash->digest (ctx->ctx.scratch, pub);
+  ctx->hash->digest (&pub_ctx, pub);
 }
index d19749cabcb338ee70e4dc847b6a9e23bb6ab4f1..45901a76d94549a6a4cb9e6a20fab9471fac74b9 100644 (file)
@@ -51,9 +51,9 @@
 */
 
 /* All hashing but H_{msg} and PRF_{msg} use plain sha256: */
-void
-_slh_sha256_init (struct sha256_ctx *ctx, const uint8_t *public_seed,
-                 uint32_t layer, uint64_t tree_idx)
+static void
+slh_sha256_init_tree (struct sha256_ctx *ctx, const uint8_t *public_seed,
+                     uint32_t layer, uint64_t tree_idx)
 {
   static const uint8_t pad[48];
   uint8_t addr_layer;
@@ -62,28 +62,54 @@ _slh_sha256_init (struct sha256_ctx *ctx, const uint8_t *public_seed,
   sha256_init (ctx);
   sha256_update (ctx, _SLH_DSA_128_SIZE, public_seed);
   /* This padding completes a sha256 block. */
-  sha256_update (ctx, sizeof(pad), pad);
+  sha256_update (ctx, sizeof (pad), pad);
   /* Compressed address. */
   addr_layer = layer;
   sha256_update (ctx, 1, &addr_layer);
   addr_tree = bswap64_if_le (tree_idx);
-  sha256_update (ctx, sizeof(addr_tree), (const uint8_t *) &addr_tree);
+  sha256_update (ctx, sizeof (addr_tree), (const uint8_t *) &addr_tree);
 }
 
-void
-_slh_sha256 (const struct sha256_ctx *tree_ctx,
-            const struct slh_address_hash *ah,
-            const uint8_t *secret, uint8_t *out)
+static void
+slh_sha256_init_hash (const struct sha256_ctx *tree_ctx, struct sha256_ctx *ctx,
+                     const struct slh_address_hash *ah)
 {
-  struct sha256_ctx ctx = *tree_ctx;
-  uint8_t digest[SHA256_DIGEST_SIZE];
+  *ctx = *tree_ctx;
   /* For compressed addr, hash only last byte of the type. */
-  sha256_update (&ctx, sizeof(*ah) - 3, (const uint8_t *) ah + 3);
-  sha256_update (&ctx, _SLH_DSA_128_SIZE, secret);
-  sha256_digest (&ctx, digest);
+  sha256_update (ctx, sizeof (*ah) - 3, (const uint8_t *) ah + 3);
+}
+
+static void
+slh_sha256_digest (struct sha256_ctx *ctx, uint8_t *out)
+{
+  uint8_t digest[SHA256_DIGEST_SIZE];
+  sha256_digest (ctx, digest);
   memcpy (out, digest, _SLH_DSA_128_SIZE);
 }
 
+static void
+slh_sha256_secret (const struct sha256_ctx *tree_ctx,
+                  const struct slh_address_hash *ah,
+                  const uint8_t *secret, uint8_t *out)
+{
+  struct sha256_ctx ctx;
+  slh_sha256_init_hash (tree_ctx, &ctx, ah);
+  sha256_update (&ctx, _SLH_DSA_128_SIZE, secret);
+  slh_sha256_digest (&ctx, out);
+}
+
+static void
+slh_sha256_node (const struct sha256_ctx *tree_ctx,
+                const struct slh_address_hash *ah,
+                const uint8_t *left, const uint8_t *right, uint8_t *out)
+{
+  struct sha256_ctx ctx;
+  slh_sha256_init_hash (tree_ctx, &ctx, ah);
+  sha256_update (&ctx, _SLH_DSA_128_SIZE, left);
+  sha256_update (&ctx, _SLH_DSA_128_SIZE, right);
+  slh_sha256_digest (&ctx, out);
+}
+
 static const uint8_t slh_pure_prefix[2] = {0, 0};
 
 void
@@ -133,3 +159,14 @@ _slh_sha256_msg_digest (const uint8_t *randomizer, const uint8_t *pub,
       digest_size -= SHA256_DIGEST_SIZE;
     }
 }
+
+const struct slh_hash
+_slh_hash_sha256 =
+  {
+    (slh_hash_init_tree_func *) slh_sha256_init_tree,
+    (slh_hash_init_hash_func *) slh_sha256_init_hash,
+    (nettle_hash_update_func *) sha3_256_update,
+    (nettle_hash_digest_func *) slh_sha256_digest,
+    (slh_hash_secret_func *) slh_sha256_secret,
+    (slh_hash_node_func *) slh_sha256_node,
+  };
index 8a97168a13e04e57de37a8020e7dfb0ce109b914..3bffe078b19b6d488ffac3bae443920c5dd53d9e 100644 (file)
@@ -46,9 +46,9 @@ struct slh_address_tree
   uint64_t tree_idx;
 };
 
-void
-_slh_shake_init (struct sha3_ctx *ctx, const uint8_t *public_seed,
-                uint32_t layer, uint64_t tree_idx)
+static void
+slh_shake_init_tree (struct sha3_ctx *ctx, const uint8_t *public_seed,
+                    uint32_t layer, uint64_t tree_idx)
 {
   struct slh_address_tree at = { bswap32_if_le (layer), 0, bswap64_if_le (tree_idx) };
 
@@ -58,36 +58,36 @@ _slh_shake_init (struct sha3_ctx *ctx, const uint8_t *public_seed,
 }
 
 static void
-slh_shake_start (const struct sha3_ctx *tree_ctx, struct sha3_ctx *ctx,
-                const struct slh_address_hash *ah)
+slh_shake_init_hash (const struct sha3_ctx *tree_ctx, struct sha3_ctx *ctx,
+                    const struct slh_address_hash *ah)
 {
   *ctx = *tree_ctx;
   sha3_256_update (ctx, sizeof (*ah), (const uint8_t *) ah);
 }
 
-void
-_slh_shake (const struct sha3_ctx *tree_ctx, const struct slh_address_hash *ah,
-           const uint8_t *secret, uint8_t *out)
+static void
+slh_shake_secret (const struct sha3_ctx *tree_ctx, const struct slh_address_hash *ah,
+                 const uint8_t *secret, uint8_t *out)
 {
   struct sha3_ctx ctx;
-  slh_shake_start (tree_ctx, &ctx, ah);;
+  slh_shake_init_hash (tree_ctx, &ctx, ah);;
   sha3_256_update (&ctx, _SLH_DSA_128_SIZE, secret);
   sha3_256_shake (&ctx, _SLH_DSA_128_SIZE, out);
 }
 
 static void
 slh_shake_node (const struct sha3_ctx *tree_ctx, const struct slh_address_hash *ah,
-                const uint8_t *left, const uint8_t *right, uint8_t *out)
+               const uint8_t *left, const uint8_t *right, uint8_t *out)
 {
   struct sha3_ctx ctx;
-  slh_shake_start (tree_ctx, &ctx, ah);;
+  slh_shake_init_hash (tree_ctx, &ctx, ah);;
   sha3_256_update (&ctx, _SLH_DSA_128_SIZE, left);
   sha3_256_update (&ctx, _SLH_DSA_128_SIZE, right);
   sha3_256_shake (&ctx, _SLH_DSA_128_SIZE, out);
 }
 
-void
-_slh_shake_digest (struct sha3_ctx *ctx, uint8_t *out)
+static void
+slh_shake_digest (struct sha3_ctx *ctx, uint8_t *out)
 {
   sha3_256_shake (ctx, _SLH_DSA_128_SIZE, out);
 }
@@ -127,10 +127,10 @@ _slh_shake_msg_digest (const uint8_t *randomizer, const uint8_t *pub,
 const struct slh_hash
 _slh_hash_shake =
   {
-    (slh_hash_init_func *) _slh_shake_init,
-    (slh_hash_secret_func *) _slh_shake,
-    (slh_hash_node_func *) slh_shake_node,
-    (slh_hash_start_func *) slh_shake_start,
+    (slh_hash_init_tree_func *) slh_shake_init_tree,
+    (slh_hash_init_hash_func *) slh_shake_init_hash,
     (nettle_hash_update_func *) sha3_256_update,
-    (nettle_hash_digest_func *)_slh_shake_digest
+    (nettle_hash_digest_func *) slh_shake_digest,
+    (slh_hash_secret_func *) slh_shake_secret,
+    (slh_hash_node_func *) slh_shake_node,
   };
index bd591af98b8d58b4ed847f6e09032a78ca56c238..0f3d2c559043e25b4041fb5411f0232516707b27 100644 (file)
@@ -44,7 +44,7 @@
    dst. For the ah argument, leaves ah->keypair and ah->height_chain
    unchanged, but overwrites the other fields. */
 static const uint8_t *
-wots_chain (const struct slh_hash_ctxs *ctx,
+wots_chain (const struct slh_hash *hash, const union slh_hash_ctx *ctx,
            struct slh_address_hash *ah,
            unsigned i, unsigned s,
            const uint8_t *src, uint8_t *dst)
@@ -57,36 +57,39 @@ wots_chain (const struct slh_hash_ctxs *ctx,
   ah->type = bswap32_if_le (SLH_WOTS_HASH);
   ah->index_hash = bswap32_if_le (i);
 
-  ctx->hash->secret(ctx->tree, ah, src, dst);
+  hash->secret (ctx, ah, src, dst);
 
   for (j = 1; j < s; j++)
     {
       ah->index_hash = bswap32_if_le (i + j);
-      ctx->hash->secret (ctx->tree, ah, dst, dst);
+      hash->secret (ctx, ah, dst, dst);
     }
 
   return dst;
 }
 
 static void
-wots_pk_init (const struct slh_hash_ctxs *ctx,
-             unsigned keypair, struct slh_address_hash *ah)
+wots_pk_init (const struct slh_hash *hash, const union slh_hash_ctx *tree_ctx,
+             unsigned keypair, struct slh_address_hash *ah,
+             union slh_hash_ctx *ctx)
 {
   ah->type = bswap32_if_le (SLH_WOTS_PK);
   ah->keypair = bswap32_if_le (keypair);
   ah->height_chain = 0;
   ah->index_hash = 0;
-  ctx->hash->start(ctx->tree, ctx->scratch, ah);
+  hash->init_hash (tree_ctx, ctx, ah);
 }
 
 void
-_wots_gen (const struct slh_hash_ctxs *ctx, const uint8_t *secret_seed,
+_wots_gen (const struct slh_hash *hash, const union slh_hash_ctx *tree_ctx,
+          const uint8_t *secret_seed,
           uint32_t keypair, uint8_t *pub)
 {
   struct slh_address_hash ah;
+  union slh_hash_ctx ctx;
   unsigned i;
 
-  wots_pk_init (ctx, keypair, &ah);
+  wots_pk_init (hash, tree_ctx, keypair, &ah, &ctx);
 
   for (i = 0; i < _WOTS_SIGNATURE_LENGTH; i++)
     {
@@ -96,72 +99,76 @@ _wots_gen (const struct slh_hash_ctxs *ctx, const uint8_t *secret_seed,
       ah.type = bswap32_if_le (SLH_WOTS_PRF);
       ah.height_chain = bswap32_if_le (i);
       ah.index_hash = 0;
-      ctx->hash->secret (ctx->tree, &ah, secret_seed, out);
+      hash->secret (tree_ctx, &ah, secret_seed, out);
 
       /* Hash chain. */
-      wots_chain (ctx, &ah, 0, 15, out, out);
+      wots_chain (hash, tree_ctx, &ah, 0, 15, out, out);
 
-      ctx->hash->update (ctx->scratch, _SLH_DSA_128_SIZE, out);
+      hash->update (&ctx, _SLH_DSA_128_SIZE, out);
     }
-  ctx->hash->digest (ctx->scratch, pub);
+  hash->digest (&ctx, pub);
 }
 
 /* Produces signature hash corresponding to the ith message nybble. Modifies addr. */
 static void
-wots_sign_one (const struct slh_hash_ctxs *ctx, const uint8_t *secret_seed,
-              uint32_t keypair,
-              unsigned i, uint8_t msg, uint8_t *sig)
+wots_sign_one (const struct slh_hash *hash, const union slh_hash_ctx *tree_ctx,
+              const uint8_t *secret_seed, uint32_t keypair,
+              unsigned i, uint8_t msg,
+              uint8_t *signature, union slh_hash_ctx *ctx)
 {
   struct slh_address_hash ah;
   uint8_t pub[_SLH_DSA_128_SIZE];
-  sig += i*_SLH_DSA_128_SIZE;
+  signature += i*_SLH_DSA_128_SIZE;
 
   /* Generate secret value. */
   ah.type = bswap32_if_le (SLH_WOTS_PRF);
   ah.keypair = bswap32_if_le (keypair);
   ah.height_chain = bswap32_if_le (i);
   ah.index_hash = 0;
-  ctx->hash->secret(ctx->tree, &ah, secret_seed, sig);
+  hash->secret (tree_ctx, &ah, secret_seed, signature);
 
   /* Hash chain. */
-  wots_chain (ctx, &ah, 0, msg, sig, sig);
+  wots_chain (hash, tree_ctx, &ah, 0, msg, signature, signature);
 
-  ctx->hash->update (ctx->scratch, _SLH_DSA_128_SIZE,
-                    wots_chain (ctx, &ah, msg, 15 - msg, sig, pub));
+  hash->update (ctx, _SLH_DSA_128_SIZE,
+               wots_chain (hash, tree_ctx, &ah, msg, 15 - msg, signature, pub));
 }
 
 void
-_wots_sign (const struct slh_hash_ctxs *ctx, const uint8_t *secret_seed,
-           unsigned keypair, const uint8_t *msg, uint8_t *signature, uint8_t *pub)
+_wots_sign (const struct slh_hash *hash, const union slh_hash_ctx *tree_ctx,
+           const uint8_t *secret_seed, unsigned keypair, const uint8_t *msg,
+           uint8_t *signature, uint8_t *pub)
 {
   struct slh_address_hash ah;
+  union slh_hash_ctx ctx;
   unsigned i;
   uint32_t csum;
 
-  wots_pk_init (ctx, keypair, &ah);
+  wots_pk_init (hash, tree_ctx, keypair, &ah, &ctx);
 
   for (i = 0, csum = 15*32; i < _SLH_DSA_128_SIZE; i++)
     {
       uint8_t m0, m1;
       m0 = msg[i] >> 4;
       csum -= m0;
-      wots_sign_one (ctx, secret_seed, keypair, 2*i, m0, signature);
+      wots_sign_one (hash, tree_ctx, secret_seed, keypair, 2*i, m0, signature, &ctx);
 
       m1 = msg[i] & 0xf;
       csum -= m1;
-      wots_sign_one (ctx, secret_seed, keypair, 2*i + 1, m1, signature);
+      wots_sign_one (hash, tree_ctx, secret_seed, keypair, 2*i + 1, m1, signature, &ctx);
     }
 
-  wots_sign_one (ctx, secret_seed, keypair, 32, csum >> 8, signature);
-  wots_sign_one (ctx, secret_seed, keypair, 33, (csum >> 4) & 0xf, signature);
-  wots_sign_one (ctx, secret_seed, keypair, 34, csum & 0xf, signature);
+  wots_sign_one (hash, tree_ctx, secret_seed, keypair, 32, csum >> 8, signature, &ctx);
+  wots_sign_one (hash, tree_ctx, secret_seed, keypair, 33, (csum >> 4) & 0xf, signature, &ctx);
+  wots_sign_one (hash, tree_ctx, secret_seed, keypair, 34, csum & 0xf, signature, &ctx);
 
-  ctx->hash->digest (ctx->scratch, pub);
+  hash->digest (&ctx, pub);
 }
 
 static void
-wots_verify_one (struct slh_hash_ctxs *ctx,
-                uint32_t keypair, unsigned i, uint8_t msg, const uint8_t *signature)
+wots_verify_one (const struct slh_hash *hash, const union slh_hash_ctx *tree_ctx,
+                uint32_t keypair, unsigned i, uint8_t msg,
+                const uint8_t *signature, union slh_hash_ctx *ctx)
 {
   struct slh_address_hash ah;
   uint8_t out[_SLH_DSA_128_SIZE];
@@ -170,35 +177,36 @@ wots_verify_one (struct slh_hash_ctxs *ctx,
   ah.keypair = bswap32_if_le (keypair);
   ah.height_chain = bswap32_if_le (i);
 
-  ctx->hash->update(ctx->scratch, _SLH_DSA_128_SIZE,
-                   wots_chain (ctx, &ah, msg, 15 - msg, signature, out));
+  hash->update (ctx, _SLH_DSA_128_SIZE,
+               wots_chain (hash, tree_ctx, &ah, msg, 15 - msg, signature, out));
 }
 
 void
-_wots_verify (struct slh_hash_ctxs *ctx,
+_wots_verify (const struct slh_hash *hash, const union slh_hash_ctx *tree_ctx,
              unsigned keypair, const uint8_t *msg, const uint8_t *signature, uint8_t *pub)
 {
   struct slh_address_hash ah;
+  union slh_hash_ctx ctx;
   unsigned i;
   uint32_t csum;
 
-  wots_pk_init (ctx, keypair, &ah);
+  wots_pk_init (hash, tree_ctx, keypair, &ah, &ctx);
 
   for (i = 0, csum = 15*32; i < _SLH_DSA_128_SIZE; i++)
     {
       uint8_t m0, m1;
       m0 = msg[i] >> 4;
       csum -= m0;
-      wots_verify_one (ctx, keypair, 2*i, m0, signature);
+      wots_verify_one (hash, tree_ctx, keypair, 2*i, m0, signature, &ctx);
 
       m1 = msg[i] & 0xf;
       csum -= m1;
-      wots_verify_one (ctx, keypair, 2*i + 1, m1, signature);
+      wots_verify_one (hash, tree_ctx, keypair, 2*i + 1, m1, signature, &ctx);
     }
 
-  wots_verify_one (ctx, keypair, 32, csum >> 8, signature);
-  wots_verify_one (ctx, keypair, 33, (csum >> 4) & 0xf, signature);
-  wots_verify_one (ctx, keypair, 34, csum & 0xf, signature);
+  wots_verify_one (hash, tree_ctx, keypair, 32, csum >> 8, signature, &ctx);
+  wots_verify_one (hash, tree_ctx, keypair, 33, (csum >> 4) & 0xf, signature, &ctx);
+  wots_verify_one (hash, tree_ctx, keypair, 34, csum & 0xf, signature, &ctx);
 
-  ctx->hash->digest (ctx->scratch, pub);
+  hash->digest (&ctx, pub);
 }
index b8f33d1031bba474d14fb88759bbf3b43e9bc5b6..f58a56850a1833258f6f667eba4f45a32f149891 100644 (file)
@@ -42,7 +42,7 @@
 static void
 xmss_leaf (const struct slh_merkle_ctx_secret *ctx, unsigned idx, uint8_t *leaf)
 {
-  _wots_gen (&ctx->pub.ctx, ctx->secret_seed, idx, leaf);
+  _wots_gen (ctx->pub.hash, &ctx->pub.tree_ctx, ctx->secret_seed, idx, leaf);
 }
 
 static void
@@ -56,32 +56,29 @@ xmss_node (const struct slh_merkle_ctx_public *ctx, unsigned height, unsigned in
       bswap32_if_le (height),
       bswap32_if_le (index),
     };
-  ctx->ctx.hash->node (ctx->ctx.tree, &ah, left, right, out);
+  ctx->hash->node (&ctx->tree_ctx, &ah, left, right, out);
 }
 
 void
 _xmss_gen (const struct slh_hash *hash,
-          void *ha, void *hb,
           const uint8_t *public_seed, const uint8_t *secret_seed,
           const struct slh_xmss_params *xmss,
           uint8_t *scratch, uint8_t *root)
 {
   struct slh_merkle_ctx_secret ctx =
     {
-      { { hash, ha, hb }, 0 },
+      { hash, {}, 0 },
       secret_seed
     };
-  hash->init(ha, public_seed, xmss->d - 1, 0);
+  hash->init_tree (&ctx.pub.tree_ctx, public_seed, xmss->d - 1, 0);
   _merkle_root (&ctx, xmss_leaf, xmss_node, xmss->h, 0, root, scratch);
 }
 
-/* TODO: Like _wots_sign, pass hash context and secret_seed as
-   separate arguments. */
 void
 _xmss_sign (const struct slh_merkle_ctx_secret *ctx, unsigned h,
            unsigned idx, const uint8_t *msg, uint8_t *signature, uint8_t *pub)
 {
-  _wots_sign (&ctx->pub.ctx, ctx->secret_seed, idx, msg, signature, pub);
+  _wots_sign (ctx->pub.hash, &ctx->pub.tree_ctx, ctx->secret_seed, idx, msg, signature, pub);
   signature += WOTS_SIGNATURE_SIZE;
 
   _merkle_sign (ctx, xmss_leaf, xmss_node, h, idx, signature);
@@ -92,7 +89,7 @@ void
 _xmss_verify (const struct slh_merkle_ctx_public *ctx, unsigned h,
              unsigned idx, const uint8_t *msg, const uint8_t *signature, uint8_t *pub)
 {
-  _wots_verify (&ctx->ctx, idx, msg, signature, pub);
+  _wots_verify (ctx->hash, &ctx->tree_ctx, idx, msg, signature, pub);
   signature += WOTS_SIGNATURE_SIZE;
 
   _merkle_verify (ctx, xmss_node, h, idx, signature, pub);
index 081af29282304d68aaccadf329e9593a6f1c8ec1..1ccc62212097edd7275cb484597fda6c5205b009 100644 (file)
@@ -102,15 +102,15 @@ test_wots_gen (const struct tstring *public_seed, const struct tstring *secret_s
               unsigned layer, uint64_t tree_idx, uint32_t keypair,
               const struct tstring *exp_pub)
 {
-  struct sha3_ctx tree_ctx;
+  union slh_hash_ctx tree_ctx;
   uint8_t pub[_SLH_DSA_128_SIZE];
   ASSERT (public_seed->length == _SLH_DSA_128_SIZE);
   ASSERT (secret_seed->length == _SLH_DSA_128_SIZE);
   ASSERT (exp_pub->length == _SLH_DSA_128_SIZE);
 
-  _slh_shake_init (&tree_ctx, public_seed->data, layer, tree_idx);
+  _slh_hash_shake.init_tree (&tree_ctx, public_seed->data, layer, tree_idx);
 
-  _wots_gen (&tree_ctx, secret_seed->data, keypair, pub);
+  _wots_gen (&_slh_hash_shake, &tree_ctx, secret_seed->data, keypair, pub);
   mark_bytes_defined (sizeof (pub), pub);
   ASSERT (MEMEQ (sizeof (pub), pub, exp_pub->data));
 }
@@ -120,7 +120,7 @@ test_wots_sign (const struct tstring *public_seed, const struct tstring *secret_
                unsigned layer, uint64_t tree_idx, uint32_t keypair, const struct tstring *msg,
                const struct tstring *exp_pub, const struct tstring *exp_sig)
 {
-  struct sha3_ctx tree_ctx;
+  union slh_hash_ctx tree_ctx;
   uint8_t sig[WOTS_SIGNATURE_SIZE];
   uint8_t pub[_SLH_DSA_128_SIZE];
   ASSERT (public_seed->length == _SLH_DSA_128_SIZE);
@@ -129,9 +129,9 @@ test_wots_sign (const struct tstring *public_seed, const struct tstring *secret_
   ASSERT (exp_pub->length == _SLH_DSA_128_SIZE);
   ASSERT (exp_sig->length == WOTS_SIGNATURE_SIZE);
 
-  _slh_shake_init (&tree_ctx, public_seed->data, layer, tree_idx);
+  _slh_hash_shake.init_tree (&tree_ctx, public_seed->data, layer, tree_idx);
 
-  _wots_sign (&tree_ctx, secret_seed->data, keypair,
+  _wots_sign (&_slh_hash_shake, &tree_ctx, secret_seed->data, keypair,
              msg->data, sig, pub);
   mark_bytes_defined (sizeof (sig), sig);
   mark_bytes_defined (sizeof (pub), pub);
@@ -139,7 +139,7 @@ test_wots_sign (const struct tstring *public_seed, const struct tstring *secret_
   ASSERT (MEMEQ (sizeof (pub), pub, exp_pub->data));
 
   memset (pub, 0, sizeof (pub));
-  _wots_verify (&tree_ctx, keypair, msg->data, sig, pub);
+  _wots_verify (&_slh_hash_shake, &tree_ctx, keypair, msg->data, sig, pub);
   ASSERT (MEMEQ (sizeof (pub), pub, exp_pub->data));
 }
 
@@ -147,7 +147,7 @@ test_wots_sign (const struct tstring *public_seed, const struct tstring *secret_
 static void
 xmss_leaf (const struct slh_merkle_ctx_secret *ctx, unsigned idx, uint8_t *leaf)
 {
-  _wots_gen (ctx->pub.tree_ctx, ctx->secret_seed, idx, leaf);
+  _wots_gen (ctx->pub.hash, &ctx->pub.tree_ctx, ctx->secret_seed, idx, leaf);
   mark_bytes_defined (SLH_DSA_SHAKE_128_SEED_SIZE, leaf);
 }
 
@@ -155,7 +155,6 @@ static void
 xmss_node (const struct slh_merkle_ctx_public *ctx, unsigned height, unsigned index,
           const uint8_t *left, const uint8_t *right, uint8_t *out)
 {
-  struct sha3_ctx sha3 = *ctx->tree_ctx;
   struct slh_address_hash ah =
     {
       bswap32_if_le (SLH_XMSS_TREE),
@@ -164,10 +163,7 @@ xmss_node (const struct slh_merkle_ctx_public *ctx, unsigned height, unsigned in
       bswap32_if_le (index),
     };
 
-  sha3_256_update (&sha3, sizeof (ah), (const uint8_t *) &ah);
-  sha3_256_update (&sha3, _SLH_DSA_128_SIZE, left);
-  sha3_256_update (&sha3, _SLH_DSA_128_SIZE, right);
-  sha3_256_shake (&sha3, _SLH_DSA_128_SIZE, out);
+  ctx->hash->node (&ctx->tree_ctx, &ah, left, right, out);
 }
 
 static void
@@ -176,10 +172,9 @@ test_merkle (const struct tstring *public_seed, const struct tstring *secret_see
             unsigned layer, uint64_t tree_idx, uint32_t idx, const struct tstring *msg,
             const struct tstring *exp_pub, const struct tstring *exp_sig)
 {
-  struct sha3_ctx tree_ctx;
   struct slh_merkle_ctx_secret ctx =
     {
-      { &tree_ctx, 0 },
+      { &_slh_hash_shake, {}, 0 },
       secret_seed->data,
     };
 
@@ -192,7 +187,7 @@ test_merkle (const struct tstring *public_seed, const struct tstring *secret_see
   ASSERT (exp_pub->length == _SLH_DSA_128_SIZE);
   ASSERT (exp_sig->length == XMSS_AUTH_SIZE (h));
 
-  _slh_shake_init (&tree_ctx, public_seed->data, layer, tree_idx);
+  _slh_hash_shake.init_tree (&ctx.pub.tree_ctx, public_seed->data, layer, tree_idx);
 
   _merkle_sign (&ctx, xmss_leaf, xmss_node, h, idx, sig);
   ASSERT (MEMEQ (exp_sig->length, sig, exp_sig->data));
@@ -208,10 +203,9 @@ test_fors_gen (const struct tstring *public_seed, const struct tstring *secret_s
               unsigned layer, uint64_t tree_idx, unsigned keypair, unsigned idx,
               const struct tstring *exp_sk, const struct tstring *exp_leaf)
 {
-  struct sha3_ctx tree_ctx;
   struct slh_merkle_ctx_secret ctx =
     {
-      { &tree_ctx, keypair },
+      { &_slh_hash_shake, {}, keypair },
       secret_seed->data,
     };
   uint8_t sk[_SLH_DSA_128_SIZE];
@@ -221,7 +215,7 @@ test_fors_gen (const struct tstring *public_seed, const struct tstring *secret_s
   ASSERT (exp_sk->length == _SLH_DSA_128_SIZE);
   ASSERT (exp_leaf->length == _SLH_DSA_128_SIZE);
 
-  _slh_shake_init (&tree_ctx, public_seed->data, layer, tree_idx);
+  _slh_hash_shake.init_tree (&ctx.pub.tree_ctx, public_seed->data, layer, tree_idx);
 
   _fors_gen (&ctx, idx, sk, leaf);
   mark_bytes_defined (sizeof (sk), sk);
@@ -236,10 +230,9 @@ test_fors_sign (const struct tstring *public_seed, const struct tstring *secret_
                unsigned layer, uint64_t tree_idx, unsigned keypair, const struct tstring *msg,
                const struct tstring *exp_pub, const struct tstring *exp_sig)
 {
-  struct sha3_ctx tree_ctx;
   struct slh_merkle_ctx_secret ctx =
     {
-      { &tree_ctx, keypair },
+      { &_slh_hash_shake, {}, keypair },
       secret_seed->data,
     };
   uint8_t pub[_SLH_DSA_128_SIZE];
@@ -250,7 +243,7 @@ test_fors_sign (const struct tstring *public_seed, const struct tstring *secret_
   ASSERT (exp_pub->length == _SLH_DSA_128_SIZE);
   ASSERT (exp_sig->length == fors->signature_size);
 
-  _slh_shake_init (&tree_ctx, public_seed->data, layer, tree_idx);
+  _slh_hash_shake.init_tree (&ctx.pub.tree_ctx, public_seed->data, layer, tree_idx);
 
   _fors_sign (&ctx, fors, msg->data, sig, pub);
   mark_bytes_defined (exp_sig->length, sig);
@@ -270,10 +263,9 @@ test_xmss_sign (const struct tstring *public_seed, const struct tstring *secret_
                unsigned layer, uint64_t tree_idx, uint32_t idx, const struct tstring *msg,
                const struct tstring *exp_pub, const struct tstring *exp_sig)
 {
-  struct sha3_ctx tree_ctx;
   struct slh_merkle_ctx_secret ctx =
     {
-      { &tree_ctx, 0 },
+      { &_slh_hash_shake, {}, 0 },
       secret_seed->data,
     };
 
@@ -285,7 +277,7 @@ test_xmss_sign (const struct tstring *public_seed, const struct tstring *secret_
   ASSERT (exp_pub->length == _SLH_DSA_128_SIZE);
   ASSERT (exp_sig->length == XMSS_SIGNATURE_SIZE (xmss_h));
 
-  _slh_shake_init (&tree_ctx, public_seed->data, layer, tree_idx);
+  _slh_hash_shake.init_tree (&ctx.pub.tree_ctx, public_seed->data, layer, tree_idx);
 
   _xmss_sign (&ctx, xmss_h, idx, msg->data, sig, pub);
   mark_bytes_defined (sizeof (pub), pub);