]> git.ipfire.org Git - thirdparty/strongswan.git/commitdiff
ml: Add ML-DSA support
authorAndreas Steffen <andreas.steffen@strongswan.org>
Thu, 26 Dec 2024 10:42:51 +0000 (11:42 +0100)
committerAndreas Steffen <andreas.steffen@strongswan.org>
Fri, 18 Jul 2025 11:07:01 +0000 (13:07 +0200)
16 files changed:
src/libstrongswan/plugins/ml/Makefile.am
src/libstrongswan/plugins/ml/ml_dsa_params.c [new file with mode: 0644]
src/libstrongswan/plugins/ml/ml_dsa_params.h [new file with mode: 0644]
src/libstrongswan/plugins/ml/ml_dsa_poly.c [new file with mode: 0644]
src/libstrongswan/plugins/ml/ml_dsa_poly.h [new file with mode: 0644]
src/libstrongswan/plugins/ml/ml_dsa_private_key.c [new file with mode: 0644]
src/libstrongswan/plugins/ml/ml_dsa_private_key.h [new file with mode: 0644]
src/libstrongswan/plugins/ml/ml_dsa_public_key.c [new file with mode: 0644]
src/libstrongswan/plugins/ml/ml_dsa_public_key.h [new file with mode: 0644]
src/libstrongswan/plugins/ml/ml_kem.c
src/libstrongswan/plugins/ml/ml_kem_params.c [moved from src/libstrongswan/plugins/ml/ml_params.c with 98% similarity]
src/libstrongswan/plugins/ml/ml_kem_params.h [moved from src/libstrongswan/plugins/ml/ml_params.h with 94% similarity]
src/libstrongswan/plugins/ml/ml_plugin.c
src/libstrongswan/plugins/ml/ml_poly.h
src/libstrongswan/plugins/ml/ml_utils.c
src/libstrongswan/plugins/ml/ml_utils.h

index e0d8d2bac9996554de87ffdc84c181a4945a9c6f..155003f39ac2c9de1263a19a02e50d19939fab88 100644 (file)
@@ -13,7 +13,11 @@ endif
 libstrongswan_ml_la_SOURCES = \
        ml_bitpacker.c ml_bitpacker.h \
        ml_kem.c ml_kem.h \
-       ml_params.c ml_params.h \
+       ml_dsa_public_key.c ml_dsa_public_key.h \
+       ml_dsa_private_key.c ml_dsa_private_key.h \
+       ml_kem_params.c ml_kem_params.h \
+       ml_dsa_params.c ml_dsa_params.h \
        ml_plugin.h ml_plugin.c \
        ml_poly.c ml_poly.h \
+       ml_dsa_poly.c ml_dsa_poly.h \
        ml_utils.c ml_utils.h
diff --git a/src/libstrongswan/plugins/ml/ml_dsa_params.c b/src/libstrongswan/plugins/ml/ml_dsa_params.c
new file mode 100644 (file)
index 0000000..2ced4e9
--- /dev/null
@@ -0,0 +1,89 @@
+/*
+ * Copyright (C) 2024 Andreas Steffen
+ *
+ * Copyright (C) secunet Security Networks AG
+ *
+ * This program is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License as published by the
+ * Free Software Foundation; either version 2 of the License, or (at your
+ * option) any later version.  See <http://www.fsf.org/copyleft/gpl.txt>.
+ *
+ * This program is distributed in the hope that it will be useful, but
+ * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
+ * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * for more details.
+ */
+
+#include "ml_dsa_params.h"
+
+/**
+ * Parameter sets for ML-DSA.
+ */
+static const ml_dsa_params_t ml_dsa_params[] = {
+       {
+               .type = KEY_ML_DSA_44,
+               .k = 4,
+               .l = 4,
+               .eta = 2,
+               .d = 3,
+               .gamma1_exp = 17,
+               .gamma2 = (ML_DSA_Q - 1) / 88,
+               .gamma2_d = 6,
+               .lambda = 128,
+               .tau = 39,
+               .beta = 78,
+               .omega = 80,
+               .privkey_len = 2560,
+               .sig_len = 2420,
+       },
+       {
+               .type = KEY_ML_DSA_65,
+               .k = 6,
+               .l = 5,
+               .eta = 4,
+               .d = 4,
+               .gamma1_exp = 19,
+               .gamma2_d = 4,
+               .gamma2 = (ML_DSA_Q - 1) / 32,
+               .lambda = 192,
+               .tau = 49,
+               .beta = 196,
+               .omega = 55,
+               .privkey_len = 4032,
+               .sig_len = 3309,
+       },
+       {
+               .type = KEY_ML_DSA_87,
+               .k = 8,
+               .l = 7,
+               .eta = 2,
+               .d = 3,
+               .gamma1_exp = 19,
+               .gamma2 = (ML_DSA_Q - 1) / 32,
+               .gamma2_d = 4,
+               .lambda = 256,
+               .tau = 60,
+               .beta = 120,
+               .omega = 75,
+               .privkey_len = 4896,
+               .sig_len = 4627,
+
+       },
+};
+
+/*
+ * Described in header
+ */
+const ml_dsa_params_t *ml_dsa_params_get(key_type_t type)
+{
+       int i;
+
+       for (i = 0; i < countof(ml_dsa_params); i++)
+       {
+               if (ml_dsa_params[i].type == type)
+               {
+                       return &ml_dsa_params[i];
+               }
+       }
+       return NULL;
+}
diff --git a/src/libstrongswan/plugins/ml/ml_dsa_params.h b/src/libstrongswan/plugins/ml/ml_dsa_params.h
new file mode 100644 (file)
index 0000000..a23d143
--- /dev/null
@@ -0,0 +1,173 @@
+/*
+ * Copyright (C) 2024 Andreas Steffen
+ *
+ * Copyright (C) secunet Security Networks AG
+ *
+ * This program is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License as published by the
+ * Free Software Foundation; either version 2 of the License, or (at your
+ * option) any later version.  See <http://www.fsf.org/copyleft/gpl.txt>.
+ *
+ * This program is distributed in the hope that it will be useful, but
+ * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
+ * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * for more details.
+ */
+
+/**
+ * @defgroup ml_dsa_params ml_dsa_params
+ * @{ @ingroup ml_p
+ */
+
+#ifndef ML_PARAMS_H_
+#define ML_PARAMS_H_
+
+#include <credentials/keys/public_key.h>
+
+typedef struct ml_dsa_params_t ml_dsa_params_t;
+
+/**
+ * Constant N used throughout the algorithms.
+ */
+#define ML_DSA_N  256
+
+/**
+ * The prime q = 2^23 - 2^13 + 1.
+ */
+#define ML_DSA_Q  8380417
+
+/**
+ * Number of bits representing (q - 1).
+ */
+#define ML_DSA_Q_BITS  23
+
+/**
+ * The inverse of q mod 2^32.
+ */
+#define ML_DSA_QINV  58728449
+
+/**
+ * Dropped bits from vector t -> (t0, t1)
+ */
+#define ML_DSA_D  13
+
+/**
+ * Number of bits representing element of vector t1.
+ */
+#define ML_DSA_T1_BITS  ML_DSA_Q_BITS - ML_DSA_D
+
+/**
+ * Length of the secret seed, rho, and K
+ */
+#define ML_DSA_SEED_LEN  32
+
+/**
+ * Length of K.
+ */
+#define ML_DSA_K_LEN  32
+
+/**
+ * Length of the public key digest tr.
+ */
+#define ML_DSA_TR_LEN  64
+
+/**
+ * Length of the message representative mu.
+ */
+#define ML_DSA_MU_LEN  64
+
+/**
+ * Length of rnd used for a randomized signature.
+ */
+#define ML_DSA_RND_LEN  32
+
+/**
+ * Length of the random private seed rho_pp.
+ */
+#define ML_DSA_RHO_PP_LEN  64
+
+/**
+ * Parameters for ML-DSA.
+ */
+struct ml_dsa_params_t {
+
+       /**
+        * Key type.
+        */
+       const key_type_t type;
+
+       /**
+        * Number of lines in matrix A.
+        */
+       uint8_t k;
+
+       /**
+        * Number of columns in matrix A.
+        */
+       uint8_t l;
+
+       /**
+        * Private key range.
+        */
+       uint8_t eta;
+
+       /**
+        * Number of bits of a compressesd s1/s2 polynomial coefficient.
+        */
+       uint8_t d;
+
+       /**
+        * Power of two exponent of gamma1.
+        */
+       u_int gamma1_exp;
+
+       /**
+        * Low-order rounding range.
+        */
+       int32_t gamma2;
+
+       /**
+        * Number of bits of a compressed w1 polynomial coefficient.
+        */
+       size_t gamma2_d;
+
+       /**
+        * Collision strength of c_tilde.
+        */
+       size_t lambda;
+
+       /**
+        * Hamming weight.
+        */
+       int32_t tau;
+
+       /**
+        * beta = eta * tau.
+        */
+       int32_t beta;
+
+       /**
+        * Maximum number of 1's in the hint h.
+        */
+       int32_t omega;
+
+       /**
+        * Private key length in bytes
+        */
+       size_t privkey_len;
+
+       /**
+        * Signature length in bytes
+        */
+       size_t sig_len;
+};
+
+/**
+ * Get parameters from a specific ML-DSA algorithm.
+ *
+ * @param                              type of key
+ * @return                             parameters, NULL if not supported
+ */
+const ml_dsa_params_t *ml_dsa_params_get(key_type_t type);
+
+#endif /** ML_PARAMS_H_ @}*/
diff --git a/src/libstrongswan/plugins/ml/ml_dsa_poly.c b/src/libstrongswan/plugins/ml/ml_dsa_poly.c
new file mode 100644 (file)
index 0000000..7524aea
--- /dev/null
@@ -0,0 +1,406 @@
+/*
+ * Copyright (C) 2024 Andreas Steffen
+ *
+ * Copyright (C) secunet Security Networks AG
+ *
+ * This program is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License as published by the
+ * Free Software Foundation; either version 2 of the License, or (at your
+ * option) any later version.  See <http://www.fsf.org/copyleft/gpl.txt>.
+ *
+ * This program is distributed in the hope that it will be useful, but
+ * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
+ * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * for more details.
+ */
+
+#include "ml_dsa_poly.h"
+#include "ml_utils.h"
+
+/**
+ * Precalculated Zeta^BitRev_8(i) mod q values for NTT Algorithms 41 and 42
+ * The values are in (centered) Montgomery form, not the verbatim values of
+ * Appendix B in FIPS 204.
+ */
+static const int32_t ml_dsa_zetas[ML_DSA_N] = {
+         0,    25847, -2608894,  -518909,   237124,  -777960,  -876248,   466468,
+   1826347,  2353451,  -359251, -2091905,  3119733, -2884855,  3111497,  2680103,
+   2725464,  1024112, -1079900,  3585928,  -549488, -1119584,  2619752, -2108549,
+  -2118186, -3859737, -1399561, -3277672,  1757237,   -19422,  4010497,   280005,
+   2706023,    95776,  3077325,  3530437, -1661693, -3592148, -2537516,  3915439,
+  -3861115, -3043716,  3574422, -2867647,  3539968,  -300467,  2348700,  -539299,
+  -1699267, -1643818,  3505694, -3821735,  3507263, -2140649, -1600420,  3699596,
+    811944,   531354,   954230,  3881043,  3900724, -2556880,  2071892, -2797779,
+  -3930395, -1528703, -3677745, -3041255, -1452451,  3475950,  2176455, -1585221,
+  -1257611,  1939314, -4083598, -1000202, -3190144, -3157330, -3632928,   126922,
+   3412210,  -983419,  2147896,  2715295, -2967645, -3693493,  -411027, -2477047,
+   -671102, -1228525,   -22981, -1308169,  -381987,  1349076,  1852771, -1430430,
+  -3343383,   264944,   508951,  3097992,    44288, -1100098,   904516,  3958618,
+  -3724342,    -8578,  1653064, -3249728,  2389356,  -210977,   759969, -1316856,
+    189548, -3553272,  3159746, -1851402, -2409325,  -177440,  1315589,  1341330,
+   1285669, -1584928,  -812732, -1439742, -3019102, -3881060, -3628969,  3839961,
+   2091667,  3407706,  2316500,  3817976, -3342478,  2244091, -2446433, -3562462,
+    266997,  2434439, -1235728,  3513181, -3520352, -3759364, -1197226, -3193378,
+    900702,  1859098,   909542,   819034,   495491, -1613174,   -43260,  -522500,
+   -655327, -3122442,  2031748,  3207046, -3556995,  -525098,  -768622, -3595838,
+    342297,   286988, -2437823,  4108315,  3437287, -3342277,  1735879,   203044,
+   2842341,  2691481, -2590150,  1265009,  4055324,  1247620,  2486353,  1595974,
+  -3767016,  1250494,  2635921, -3548272, -2994039,  1869119,  1903435, -1050970,
+  -1333058,  1237275, -3318210, -1430225,  -451100,  1312455,  3306115, -1962642,
+  -1279661,  1917081, -2546312, -1374803,  1500165,   777191,  2235880,  3406031,
+   -542412, -2831860, -1671176, -1846953, -2584293, -3724270,   594136, -3776993,
+  -2013608,  2432395,  2454455,  -164721,  1957272,  3369112,   185531, -1207385,
+  -3183426,   162844,  1616392,  3014001,   810149,  1652634, -3694233, -1799107,
+  -3038916,  3523897,  3866901,   269760,  2213111,  -975884,  1717735,   472078,
+   -426683,  1723600, -1803090,  1910376, -1667432, -1104333,  -260646, -3833893,
+  -2939036, -2235985,  -420899, -2286327,   183443,  -976891,  1612842, -3545687,
+   -554416,  3919660,   -48306, -1362209,  3937738,  1400424,  -846154,  1976782
+};
+
+/*
+ * Described in header
+ */
+void ml_dsa_poly_ntt(ml_dsa_poly_t *a)
+{
+       u_int len, start, j, k = 0;
+       int32_t zeta, t;
+
+       for (len = ML_DSA_N/2; len > 0; len >>= 1)
+       {
+               for (start = 0; start < ML_DSA_N; start = j + len)
+               {
+                       zeta = ml_dsa_zetas[++k];
+                       for(j = start; j < start + len; ++j)
+                       {
+                               t = ml_montgomery_reduce((int64_t)zeta * a->f[j + len]);
+                               a->f[j + len] = a->f[j] - t;
+                               a->f[j] = a->f[j] + t;
+                       }
+               }
+       }
+}
+
+/*
+ * Described in header
+ */
+void ml_dsa_poly_inv_ntt(ml_dsa_poly_t *a)
+{
+       u_int start, len, j, k = ML_DSA_N;
+       int32_t t, zeta;
+
+       /* scaling factor 256^-1 mod q with squared Montgomery multiplier to
+        * implicitly convert results to Montgomery form (i.e. 2^64/256 mod q)
+        */
+       const int32_t factor = 41978;
+
+       for (len = 1; len < ML_DSA_N; len <<= 1)
+       {
+               for (start = 0; start < ML_DSA_N; start = j + len)
+               {
+                       zeta = -ml_dsa_zetas[--k];
+                       for (j = start; j < start + len; ++j)
+                       {
+                               t = a->f[j];
+                               a->f[j] = t + a->f[j + len];
+                               a->f[j + len] = t - a->f[j + len];
+                               a->f[j + len] = ml_montgomery_reduce((int64_t)zeta * a->f[j + len]);
+                       }
+               }
+       }
+
+       for(j = 0; j < ML_DSA_N; ++j)
+       {
+               a->f[j] = ml_montgomery_reduce((int64_t)factor * a->f[j]);
+       }
+}
+
+/*
+ * Described in header
+ */
+void ml_dsa_poly_ntt_vec(u_int k, ml_dsa_poly_t *a)
+{
+       while (k--)
+       {
+               ml_dsa_poly_ntt(&a[k]);
+       }
+}
+
+/*
+ * Described in header
+ */
+void ml_dsa_poly_inv_ntt_vec(u_int k, ml_dsa_poly_t *a)
+{
+       while (k--)
+       {
+               ml_dsa_poly_inv_ntt(&a[k]);
+       }
+}
+
+/*
+ * Described in header
+ */
+void ml_dsa_poly_copy_vec(u_int k, ml_dsa_poly_t *a, ml_dsa_poly_t *b)
+{
+       while (k--)
+       {
+               b[k] = a[k];
+       }
+}
+
+/*
+ * Described in header
+ */
+void ml_dsa_poly_add_vec(u_int k, ml_dsa_poly_t *a, ml_dsa_poly_t *b,
+                                                ml_dsa_poly_t *res)
+{
+       u_int n;
+
+       while (k--)
+       {
+               for (n = 0; n < ML_DSA_N; n++)
+               {
+                       res[k].f[n] = a[k].f[n] + b[k].f[n];
+               }
+       }
+}
+
+/*
+ * Described in header
+ */
+void ml_dsa_poly_sub_vec(u_int k, ml_dsa_poly_t *a, ml_dsa_poly_t *b,
+                                                ml_dsa_poly_t *res)
+{
+       u_int n;
+
+       while (k--)
+       {
+               for (n = 0; n < ML_DSA_N; n++)
+               {
+                       res[k].f[n] = a[k].f[n] - b[k].f[n];
+               }
+       }
+}
+
+/*
+ * Described in header
+ */
+void ml_dsa_poly_mult_const_vec(u_int k, ml_dsa_poly_t *a, ml_dsa_poly_t *b,
+                                                               ml_dsa_poly_t *res)
+{
+       u_int n;
+
+       /* pointwise product of polynomial vector b with a polynomial a */
+       while (k--)
+       {
+               for (n = 0; n < ML_DSA_N; n++)
+               {
+                       res[k].f[n] = ml_montgomery_reduce((int64_t)a->f[n] * b[k].f[n]);
+               }
+       }
+}
+
+/*
+ * Described in header
+ */
+void ml_dsa_poly_mult_vec(u_int k, ml_dsa_poly_t *a, ml_dsa_poly_t *b,
+                                                 ml_dsa_poly_t *res)
+{
+       u_int n;
+
+       /* initialize result polynomial to all zeros */
+       for (n = 0; n < ML_DSA_N; n++)
+       {
+               res->f[n] = 0;
+       }
+
+       /* compute the inner product of vectors a and b */
+       while (k--)
+       {
+               for (n = 0; n < ML_DSA_N; n++)
+               {
+                       res->f[n] += ml_montgomery_reduce((int64_t)a[k].f[n] * b[k].f[n]);
+               }
+       }
+}
+
+/*
+ * Described in header
+ */
+void ml_dsa_poly_mult_mat(u_int k, u_int l, ml_dsa_poly_t *a, ml_dsa_poly_t *b,
+                                                 ml_dsa_poly_t *res)
+{
+       u_int i;
+
+       for (i = 0; i < k; i++)
+       {
+               ml_dsa_poly_mult_vec(l, &a[i*l], b, &res[i]);
+       }
+}
+
+/*
+ * Described in header
+ */
+void ml_dsa_poly_reduce_vec(u_int k, ml_dsa_poly_t *a)
+{
+       int32_t r;
+       u_int n;
+
+       while (k--)
+       {
+               for (n = 0; n < ML_DSA_N; n++)
+               {
+                       r = (a[k].f[n] + (1 << 22)) >> 23;
+                       a[k].f[n] -= r * ML_DSA_Q;
+               }
+       }
+}
+
+/*
+ * Described in header
+ */
+void ml_dsa_poly_cond_add_q_vec(u_int k, ml_dsa_poly_t *a)
+{
+       u_int n;
+
+       while (k--)
+       {
+               for (n = 0; n < ML_DSA_N; n++)
+               {
+                       a[k].f[n] += (a[k].f[n] >> 31) & ML_DSA_Q;
+               }
+       }
+}
+
+/*
+ * Described in header
+ */
+void ml_dsa_poly_power2round_vec(u_int k, ml_dsa_poly_t *a, ml_dsa_poly_t *a0,
+                                                                ml_dsa_poly_t *a1)
+{
+       int32_t t0, t1;
+       u_int n;
+
+       while (k--)
+       {
+               for (n = 0; n < ML_DSA_N; n++)
+               {
+                       t1 = (a[k].f[n] + (1 << (ML_DSA_D-1)) - 1) >> ML_DSA_D;
+                       t0 =  a[k].f[n] - (t1 << ML_DSA_D);
+
+                       a0[k].f[n] = t0;
+                       a1[k].f[n] = t1;
+               }
+       }
+}
+
+/*
+ * Described in header
+ */
+void ml_dsa_poly_shift_left_vec(u_int k, ml_dsa_poly_t *a)
+{
+       u_int n;
+
+       while (k--)
+       {
+               for (n = 0; n < ML_DSA_N; n++)
+               {
+                       a[k].f[n] <<= ML_DSA_D;
+               }
+       }
+}
+
+/*
+ * Described in header
+ */
+void ml_dsa_poly_decompose_vec(u_int k, ml_dsa_poly_t *a, ml_dsa_poly_t *a0,
+                                                          ml_dsa_poly_t *a1, int32_t gamma2)
+{
+       u_int n;
+
+       while (k--)
+       {
+               for (n = 0; n < ML_DSA_N; n++)
+               {
+                       ml_decompose(a[k].f[n], &a0[k].f[n], &a1[k].f[n], gamma2);
+               }
+       }
+}
+
+/*
+ * Described in header
+ */
+void ml_dsa_poly_use_hint_vec(u_int k, ml_dsa_poly_t *a, ml_dsa_poly_t *h,
+                                                         ml_dsa_poly_t *a1, int32_t gamma2)
+{
+       u_int n;
+
+       while (k--)
+       {
+               for (n = 0; n < ML_DSA_N; n++)
+               {
+                       a1[k].f[n] = ml_use_hint(a[k].f[n], h[k].f[n], gamma2);
+               }
+       }
+}
+
+/*
+ * Described in header
+ */
+u_int ml_dsa_poly_make_hint_vec(u_int k, ml_dsa_poly_t *a0, ml_dsa_poly_t *a1,
+                                                               ml_dsa_poly_t *h, int32_t gamma2)
+{
+       u_int n, s = 0;
+
+       while (k--)
+       {
+               for (n = 0; n < ML_DSA_N; n++)
+               {
+                       h[k].f[n] = ml_make_hint(a0[k].f[n], a1[k].f[n], gamma2);
+                       s += h[k].f[n];
+               }
+       }
+
+       return s;
+}
+
+/*
+ * Described in header
+ */
+bool ml_dsa_poly_check_bound(ml_dsa_poly_t *a, int32_t bound)
+{
+       int32_t t;
+       u_int n;
+
+  /* it is ok to leak which coefficient violates the bound since the probability
+   * for each coefficient is independent of secret data but we must not leak the
+   * sign of the centralized representative.
+   */
+       for (n = 0; n < ML_DSA_N; n++)
+       {
+               t = a->f[n] >> 31;
+               t = a->f[n] - (t & 2*a->f[n]);
+
+               if (t >= bound)
+               {
+                       return FALSE;
+               }
+       }
+
+       return TRUE;
+}
+
+/*
+ * Described in header
+ */
+bool ml_dsa_poly_check_bound_vec(u_int k, ml_dsa_poly_t *a, int32_t bound)
+{
+       while (k--)
+       {
+               if (!ml_dsa_poly_check_bound(&a[k], bound))
+               {
+                       return FALSE;
+               }
+       }
+
+       return TRUE;
+}
diff --git a/src/libstrongswan/plugins/ml/ml_dsa_poly.h b/src/libstrongswan/plugins/ml/ml_dsa_poly.h
new file mode 100644 (file)
index 0000000..06aa8c5
--- /dev/null
@@ -0,0 +1,240 @@
+/*
+ * Copyright (C) 2024 Andreas Steffen
+ *
+ * Copyright (C) secunet Security Networks AG
+ *
+ * This program is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License as published by the
+ * Free Software Foundation; either version 2 of the License, or (at your
+ * option) any later version.  See <http://www.fsf.org/copyleft/gpl.txt>.
+ *
+ * This program is distributed in the hope that it will be useful, but
+ * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
+ * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * for more details.
+ */
+
+/**
+ * @defgroup ml_dsa_poly ml_dsa_poly
+ * @{ @ingroup ml_p
+ */
+
+#ifndef ML_DSA_POLY_H_
+#define ML_DSA_POLY_H_
+
+#include "ml_dsa_params.h"
+
+typedef struct ml_dsa_poly_t ml_dsa_poly_t;
+
+/**
+ * Represents an element in R_q = Z_q[X]/(X^n + 1) i.e. a polynomial of the
+ * form f[0] + f[1]*X + ... + f[n-1]*X^n-1.
+ */
+struct ml_dsa_poly_t {
+
+       /**
+        * Coefficients of the polynomial.
+        */
+       int32_t f[ML_DSA_N];
+};
+
+/**
+ * Computes the NTT.
+ *
+ * Algorithm 41 in FIPS 204.
+ *
+ * @param a            polynomial a (in-place NTT computation)
+ */
+void ml_dsa_poly_ntt(ml_dsa_poly_t *a);
+
+/**
+ * Computes the inverse NTT including scaling.
+ *
+ * Algorithm 42 in FIPS 204.
+ *
+ * @param a            polynomial a (in-place NTT computation)
+ */
+void ml_dsa_poly_inv_ntt(ml_dsa_poly_t *a);
+
+/**
+ * Computes the NTT of each vector element.
+ *
+ * @param k            vector size
+ * @param a            vector of polynomials a (in-place NTT computation)
+ */
+void ml_dsa_poly_ntt_vec(u_int k, ml_dsa_poly_t *a);
+
+/**
+ * Computes the inverse NTT of each vector element.
+ *
+ * @param k            vector size
+ * @param a            vector of polynomials a (in-place NTT computation)
+ */
+void ml_dsa_poly_inv_ntt_vec(u_int k, ml_dsa_poly_t *a);
+
+/*
+ * Copy a polynomial vector.
+ *
+ * @param k            vector size
+ * @param a            vector of polynomials a
+ * @param b            vector of polynomials b
+ */
+void ml_dsa_poly_copy_vec(u_int k, ml_dsa_poly_t *a, ml_dsa_poly_t *b);
+
+/**
+ * Add polynomials in vector a and b (a[i] + b[i] mod q for i in 0 to k-1).
+ *
+ * @param k            vector size
+ * @param a            vector of polynomials a
+ * @param b            vector of polynomials b
+ * @param res  vector of resulting polynomials (can be one of the others)
+ */
+void ml_dsa_poly_add_vec(u_int k, ml_dsa_poly_t *a, ml_dsa_poly_t *b,
+                                                ml_dsa_poly_t *res);
+
+/**
+ * Subtract polynomials in vector a and b (a[i] - b[i] mod q for i in 0 to k-1).
+ *
+ * @param k            vector size
+ * @param a            vector of polynomials a
+ * @param b            vector of polynomials b
+ * @param res  vector of resulting polynomials (can be one of the others)
+ */
+void ml_dsa_poly_sub_vec(u_int k, ml_dsa_poly_t *a, ml_dsa_poly_t *b,
+                                                ml_dsa_poly_t *res);
+
+/**
+ * Pointwise product of a polynomial vector b with a polynomial a.
+ *
+ * @param k            vector size
+ * @param a            polynomial a
+ * @param b            vector of polynomials b
+ * @param res  result vector of polynomials
+ */
+void ml_dsa_poly_mult_const_vec(u_int k, ml_dsa_poly_t *a, ml_dsa_poly_t *b,
+                                                               ml_dsa_poly_t *res);
+
+/**
+ * Dot product of two polynomial vectors a and b.
+ *
+ * @param k            vector size
+ * @param a            vector of polynomials a
+ * @param b            vector of polynomials b
+ * @param res  result polynomial
+ */
+void ml_dsa_poly_mult_vec(u_int k, ml_dsa_poly_t *a, ml_dsa_poly_t *b,
+                                                 ml_dsa_poly_t *res);
+
+/**
+ * Dot product of a matrix a with a vector b.
+ *
+ * @param k            number of lines in matrix a and size of vector res
+ * @param l            number of columns in matrix a and size of vector b
+ * @param a            kxl matrix a
+ * @param b            vector of polynomials b
+ * @param res  result vector of polynomials
+ */
+void ml_dsa_poly_mult_mat(u_int k, u_int l, ml_dsa_poly_t *a, ml_dsa_poly_t *b,
+                                                 ml_dsa_poly_t *res);
+
+/**
+ * Computes r = a mod q such that -6283008 <= r <= 6283008 ((2^31-2^22-1) mod q).
+ *
+ * @param k            vector size
+ * @param a            vector of polynomials a
+ */
+void ml_dsa_poly_reduce_vec(u_int k, ml_dsa_poly_t *a);
+
+/**
+ * Add q if coefficient is negative.
+ *
+ * @param k            vector size
+ * @param a            vector of polynomials a
+ */
+void ml_dsa_poly_cond_add_q_vec(u_int k, ml_dsa_poly_t *a);
+
+/**
+ * Decomposes a into (a1, a0) such that a ≡ a1 * 2^d + a0 mod q.
+ *
+ * Algorithm 35 of FIPS 204.
+ *
+ * @param k            vector size
+ * @param a            vector of polynomials a
+ * @param a0   vector of polynomials a0
+ * @param a1   vector of polynomials a1
+ */
+void ml_dsa_poly_power2round_vec(u_int k, ml_dsa_poly_t *a, ml_dsa_poly_t *a0,
+                                                                ml_dsa_poly_t *a1);
+
+/**
+ * Multiply polynomial by 2^d.
+ *
+ * @param k            vector size
+ * @param a            vector of polynomials a
+ */
+void ml_dsa_poly_shift_left_vec(u_int k, ml_dsa_poly_t *a);
+
+/**
+ * Decomposes a into (a1, a0) such that a ≡ a1 * (2 * gamma2) + a0 mod q.
+ *
+ * Algorithm 36 of FIPS 204.
+ *
+ * @param k                    vector size
+ * @param a                    vector of polynomials a
+ * @param a0           vector of polynomials a0
+ * @param a1           vector of polynomials a1
+ * @param gamma2       parameter gamma2
+*/
+void ml_dsa_poly_decompose_vec(u_int k, ml_dsa_poly_t *a, ml_dsa_poly_t *a0,
+                                                          ml_dsa_poly_t *a1, int32_t gamma2);
+
+/**
+ * Return the high bits a1 of a adjusted according to hint h.
+ *
+ * Algorithm 40 of FIPS 204.
+ *
+ * @param k                    vector size
+ * @param a                    vector of polynomials a
+ * @param h                    vector of polynomials h
+ * @param a1           vector of polynomials a1
+ * @param gamma2       parameter gamma2
+*/
+void ml_dsa_poly_use_hint_vec(u_int k, ml_dsa_poly_t *a, ml_dsa_poly_t *h,
+                                                         ml_dsa_poly_t *a1, int32_t gamma2);
+
+/**
+ * Compute a hint bit indicating whether the low bits a0 of the
+ * input element overflow into the high bits a1.
+ *
+ * Algorithm 39 in FIPS 204.
+ *
+ * @param k                    vector size
+ * @param a0           vector of polynomials containg low bits
+ * @param a1           vector of polynomials containing high bits
+ * @param h                    vector of polynomials containing hint bits
+ * @param gamma2       parameter gamma2
+ * @return                     total numer of hint bits
+ */
+u_int ml_dsa_poly_make_hint_vec(u_int k, ml_dsa_poly_t *a0, ml_dsa_poly_t *a1,
+                                                               ml_dsa_poly_t *h, int32_t gamma2);
+
+/**
+ * Check infinity norm of polynomial against given bound.
+ *
+ * @param a            polynomial a
+ * @param bound        norm bound
+ * @return             TRUE if bound is not exceeded
+ */
+bool ml_dsa_poly_check_bound(ml_dsa_poly_t *a, int32_t bound);
+
+/**
+ * Check infinity norm of vector of polynomials against given bound.
+ *
+ * @param k            vector size
+ * @param a            vector if polynomials a
+ * @param bound        norm bound
+ * @return             TRUE if bound is not exceeded
+ */
+bool ml_dsa_poly_check_bound_vec(u_int k, ml_dsa_poly_t *a, int32_t bound);
+
+#endif /** ML_DSA_POLY_H_ @}*/
diff --git a/src/libstrongswan/plugins/ml/ml_dsa_private_key.c b/src/libstrongswan/plugins/ml/ml_dsa_private_key.c
new file mode 100644 (file)
index 0000000..359ce4e
--- /dev/null
@@ -0,0 +1,1025 @@
+/*
+ * Copyright (C) 2024-2025 Andreas Steffen
+ *
+ * Copyright (C) secunet Security Networks AG
+ *
+ * This program is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License as published by the
+ * Free Software Foundation; either version 2 of the License, or (at your
+ * option) any later version.  See <http://www.fsf.org/copyleft/gpl.txt>.
+ *
+ * This program is distributed in the hope that it will be useful, but
+ * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
+ * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * for more details.
+ */
+
+#include "ml_dsa_private_key.h"
+#include "ml_dsa_params.h"
+#include "ml_dsa_poly.h"
+#include "ml_utils.h"
+#include "ml_bitpacker.h"
+
+#include <library.h>
+#include <utils/debug.h>
+#include <asn1/asn1.h>
+#include <credentials/cred_encoding.h>
+#include <credentials/keys/public_key.h>
+#include <credentials/keys/signature_params.h>
+
+typedef struct private_private_key_t private_private_key_t;
+
+/**
+ * Private data
+ */
+struct private_private_key_t {
+
+       /**
+        * Public interface
+        */
+       private_key_t public;
+
+       /**
+        * Key type
+        */
+       key_type_t type;
+
+       /**
+        * Parameter set.
+        */
+       const ml_dsa_params_t *params;
+
+       /**
+        * Secret key seed
+        */
+       chunk_t keyseed;
+
+       /**
+        * Public key
+        */
+       chunk_t pubkey;
+
+       /**
+        * Private key
+        */
+       chunk_t privkey;
+
+       /**
+        * SHAKE-128 instance.
+        */
+       xof_t *G;
+
+       /**
+        * SHAKE-256 instance.
+        */
+       xof_t *H;
+
+       /**
+        * Reference count
+        */
+       refcount_t ref;
+};
+
+/* from ml_dsa_public_key.c */
+bool ml_dsa_expand_a(const ml_dsa_params_t *params, xof_t *G, chunk_t rho,
+                                               ml_dsa_poly_t *a);
+
+bool ml_dsa_sample_in_ball(xof_t *H, int32_t tau, chunk_t rho, ml_dsa_poly_t *c);
+
+bool ml_dsa_w1_encode(u_int k, ml_dsa_poly_t *w1, chunk_t w1_enc, u_int d);
+
+bool ml_dsa_fingerprint(chunk_t pubkey, key_type_t type,
+                                               cred_encoding_type_t enc_type, chunk_t *fp);
+
+bool ml_dsa_type_supported(key_type_t type);
+
+/**
+ * Decode the secret key (skDecode)
+ *
+ * Algorithm 25 in FIPS 204.
+ */
+static bool decode_secret_key(private_private_key_t *this, chunk_t rho,
+                                                         chunk_t K, chunk_t tr, ml_dsa_poly_t *s1,
+                                                         ml_dsa_poly_t *s2, ml_dsa_poly_t *t0)
+{
+       const u_int k = this->params->k;
+       const u_int l = this->params->l;
+       const u_int d = this->params->d;
+       const u_int eta = this->params->eta;
+       u_int i, j, n;
+       uint32_t value;
+       ml_bitpacker_t *bitpacker;
+       chunk_t sk;
+       bool success = FALSE;
+
+       sk = this->privkey;
+       memcpy(rho.ptr, sk.ptr, rho.len);
+       sk = chunk_skip(sk, rho.len);
+
+       memcpy(K.ptr, sk.ptr, K.len);
+       sk = chunk_skip(sk, K.len);
+
+       memcpy(tr.ptr, sk.ptr, tr.len);
+       sk = chunk_skip(sk, tr.len);
+
+       /* unpack the vectors s1, s2 and t0 from the private key blob */
+       bitpacker = ml_bitpacker_create_from_data(sk);
+       for (j = 0; j < l; j++)
+       {
+               for (n = 0; n < ML_DSA_N; n++)
+               {
+                       if (!bitpacker->read_bits(bitpacker, &value, d) || value > 2*eta)
+                       {
+                               goto end;
+                       }
+                       s1[j].f[n] = eta - (int32_t)value;
+               }
+       }
+       for (i = 0; i < k; i++)
+       {
+               for (n = 0; n < ML_DSA_N; n++)
+               {
+                       if (!bitpacker->read_bits(bitpacker, &value, d) || value > 2*eta)
+                       {
+                               goto end;
+                       }
+                       s2[i].f[n] = eta - (int32_t)value;
+               }
+       }
+       for (i = 0; i < k; i++)
+       {
+               for (n = 0; n < ML_DSA_N; n++)
+               {
+                       if (!bitpacker->read_bits(bitpacker, &value, ML_DSA_D))
+                       {
+                               goto end;
+                       }
+                       t0[i].f[n] = (1 << (ML_DSA_D-1)) - (int32_t)value;
+               }
+       }
+       success = TRUE;
+
+end:
+       bitpacker->destroy(bitpacker);
+
+       return success;
+}
+
+/**
+ * Samples a vector y such that each polynomial has coefficients between
+ * -gamma1 + 1 and gamma1.
+ *
+ * Algorithm 34 in FIPS 204.
+ */
+static bool expand_mask(private_private_key_t *this, chunk_t rho, u_int nonce,
+                                               ml_dsa_poly_t *y)
+{
+       const u_int gamma1_exp = this->params->gamma1_exp;
+       const u_int l = this->params->l;
+       ml_bitpacker_t *bitpacker;
+       chunk_t v;
+       uint32_t value;
+       u_int j, n;
+
+       v = chunk_alloca(32*(1 + gamma1_exp));
+
+       for (j = 0; j < l; j++)
+       {
+               rho.ptr[ML_DSA_RHO_PP_LEN]   = nonce & 0x00ff;
+               rho.ptr[ML_DSA_RHO_PP_LEN+1] = nonce++ >> 8;
+
+               if (!this->H->set_seed(this->H, rho) ||
+                       !this->H->get_bytes(this->H, v.len, v.ptr))
+               {
+                       return FALSE;
+               }
+
+               bitpacker = ml_bitpacker_create_from_data(v);
+               for (n = 0; n < ML_DSA_N; n++)
+               {
+                       if (!bitpacker->read_bits(bitpacker, &value, 1 + gamma1_exp))
+                       {
+                               bitpacker->destroy(bitpacker);
+                               return FALSE;
+                       }
+                       y[j].f[n] = (1 << gamma1_exp) - (int32_t)value;
+               }
+               bitpacker->destroy(bitpacker);
+       }
+
+       return TRUE;
+}
+
+/**
+ * Encodes a signature into a byte string (sigEncode).
+ *
+ * Algorithm 26 in FIPS 204.
+ */
+static chunk_t encode_signature(private_private_key_t *this, chunk_t c_tilde,
+                                                               ml_dsa_poly_t *z, ml_dsa_poly_t *h)
+{
+       const u_int k = this->params->k;
+       const u_int l = this->params->l;
+       const u_int gamma1_exp = this->params->gamma1_exp;
+       const u_int omega = this->params->omega;
+       ml_bitpacker_t *bitpacker;
+       chunk_t signature, sig;
+       u_int i, j, n, index = 0;
+
+       signature = chunk_alloc(this->params->sig_len);
+       sig = signature;
+
+       /* encode byte string c_tilde */
+       memcpy(sig.ptr, c_tilde.ptr, c_tilde.len);
+       sig = chunk_skip(sig, c_tilde.len);
+
+       /* encode vector of polynomials z in packed format */
+       bitpacker = ml_bitpacker_create(sig);
+       for (j = 0; j < l; j++)
+       {
+               for (n = 0; n < ML_DSA_N; n++)
+               {
+                       if (!bitpacker->write_bits(bitpacker, (1 << gamma1_exp) - z[j].f[n],
+                                                                                                  1 +  gamma1_exp))
+                       {
+                               bitpacker->destroy(bitpacker);
+                               chunk_free(&signature);
+                               return chunk_empty;
+                       }
+               }
+       }
+       bitpacker->destroy(bitpacker);
+       sig = chunk_skip(sig, 32 * (1 + gamma1_exp) * l);
+
+       /* encode vector of polynomials with binary coefficients h (HintBitPack)
+        * Algorithm 20 in FIPS 204.
+        */
+       memset(sig.ptr, 0x00, omega);
+
+       for (i = 0; i < k; i++)
+       {
+               for (n = 0; n < ML_DSA_N; n++)
+               {
+                       if (h[i].f[n] != 0)
+                       {
+                               sig.ptr[index++] = (uint8_t)n;
+                       }
+               }
+               sig.ptr[omega + i] = (uint8_t)index;
+       }
+
+       return signature;
+}
+
+METHOD(private_key_t, sign, bool,
+       private_private_key_t *this, signature_scheme_t scheme,
+       void *params, chunk_t data, chunk_t *signature)
+{
+
+       const u_int k = this->params->k;
+       const u_int l = this->params->l;
+       const u_int gamma2 = this->params->gamma2;
+       const u_int omega = this->params->omega;
+       const int32_t bound1 = (1 << this->params->gamma1_exp) - this->params->beta;
+       const int32_t bound2 = gamma2 - this->params->beta;
+       pqc_params_t pqc_params;
+       ml_dsa_poly_t a[k*l], s1[l], s2[k], t0[k];
+       ml_dsa_poly_t y[l], z[l], w0[k], w1[k], h[k], c;
+       chunk_t rho, K , tr, seed, mu, rnd, rho_pp, w1_enc, c_tilde;
+       u_int kappa = 0;
+       rng_t *rng;
+
+       /* set empty signature in case of failure */
+       *signature = chunk_empty;
+
+       if (key_type_from_signature_scheme(scheme) != this->type)
+       {
+               DBG1(DBG_LIB, "signature scheme %N not supported",
+                                          signature_scheme_names, scheme);
+               return FALSE;
+       }
+
+       rho     = chunk_alloca(ML_DSA_SEED_LEN);
+       K       = chunk_alloca(ML_DSA_K_LEN);
+       tr      = chunk_alloca(ML_DSA_TR_LEN);
+       rnd     = chunk_alloca(ML_DSA_RND_LEN);
+       mu      = chunk_alloca(ML_DSA_MU_LEN);
+       rho_pp  = chunk_alloca(ML_DSA_RHO_PP_LEN + 2);
+       w1_enc  = chunk_alloca(32 * this->params->gamma2_d * k);
+       c_tilde = chunk_alloca(this->params->lambda / 4);
+
+       if (!decode_secret_key(this, rho, K, tr, s1, s2, t0) ||
+               !ml_dsa_expand_a(this->params, this->G, rho, a))
+       {
+               goto cleanup;
+       }
+
+       ml_dsa_poly_ntt_vec(l, s1);
+       ml_dsa_poly_ntt_vec(k, s2);
+       ml_dsa_poly_ntt_vec(k, t0);
+
+       /* set PQC signature params */
+       if (!pqc_params_create(params, &pqc_params))
+       {
+               goto cleanup;
+       }
+
+       /* compute message representative mu */
+       seed = chunk_cat("cmc", tr, pqc_params.pre_ctx, data);
+       if (!this->H->set_seed(this->H, seed) ||
+               !this->H->get_bytes(this->H, ML_DSA_MU_LEN, mu.ptr))
+       {
+               chunk_free(&seed);
+               goto cleanup;
+       }
+       chunk_free(&seed);
+
+
+       /* deterministic or randomized signature? */
+       if      (pqc_params.deterministic)
+       {
+               memset(rnd.ptr, 0x00, ML_DSA_RND_LEN);
+       }
+       else
+       {
+               rng = lib->crypto->create_rng(lib->crypto, RNG_STRONG);
+               if (!rng || !rng->get_bytes(rng, ML_DSA_RND_LEN, rnd.ptr))
+               {
+                       DESTROY_IF(rng);
+                       goto cleanup;
+               }
+               rng->destroy(rng);
+       }
+
+       /* compute random private seed rho_pp */
+       seed = chunk_cat("ccc", K, rnd, mu);
+       if (!this->H->set_seed(this->H, seed) ||
+               !this->H->get_bytes(this->H, ML_DSA_RHO_PP_LEN, rho_pp.ptr))
+       {
+               chunk_clear(&seed);
+               goto cleanup;
+       }
+       chunk_clear(&seed);
+
+       while (TRUE)
+       {
+               if (!expand_mask(this, rho_pp, kappa, y))
+               {
+                       goto cleanup;
+               }
+               kappa += l;
+
+               /* multiply vector y with matrix a via NTT resulting in vector w */
+               ml_dsa_poly_copy_vec(l, y, z);
+               ml_dsa_poly_ntt_vec(l, z);
+               ml_dsa_poly_mult_mat(k, l, a, z, w1);
+               ml_dsa_poly_reduce_vec(k, w1);
+               ml_dsa_poly_inv_ntt_vec(k, w1);
+               ml_dsa_poly_cond_add_q_vec(k, w1);
+
+               /* decompose elements of vector w into high and low bits */
+               ml_dsa_poly_decompose_vec(k, w1, w0, w1, gamma2);
+
+               /* compress the w1 vector into a byte string */
+               if (!ml_dsa_w1_encode(k, w1, w1_enc, this->params->gamma2_d))
+               {
+                       goto cleanup;
+               }
+
+               /* compute commitment hash c_tilde */
+               seed = chunk_cat("cc", mu, w1_enc);
+               if (!this->H->set_seed(this->H, seed) ||
+                       !this->H->get_bytes(this->H, this->params->lambda/4, c_tilde.ptr))
+               {
+                       chunk_clear(&seed);
+                       goto cleanup;
+               }
+               chunk_clear(&seed);
+
+               /* verifier's challenge */
+               if (!ml_dsa_sample_in_ball(this->H, this->params->tau, c_tilde, &c))
+               {
+                       goto cleanup;
+               }
+               ml_dsa_poly_ntt(&c);
+
+               /* compute z, reject if it reveals secret */
+               ml_dsa_poly_mult_const_vec(l, &c, s1, z);
+               ml_dsa_poly_inv_ntt_vec(l, z);
+               ml_dsa_poly_add_vec(l, z, y, z);
+               ml_dsa_poly_reduce_vec(l, z);
+
+               if (!ml_dsa_poly_check_bound_vec(l, z, bound1))
+               {
+                       continue;
+               }
+
+               /* check that subtracting cs2 does not change high bits of w and
+                * low bits do not reveal secret information
+                */
+               ml_dsa_poly_mult_const_vec(k, &c, s2, h);
+               ml_dsa_poly_inv_ntt_vec(k, h);
+               ml_dsa_poly_sub_vec(k, w0, h, w0);
+               ml_dsa_poly_reduce_vec(k, w0);
+
+               if (!ml_dsa_poly_check_bound_vec(k, w0, bound2))
+               {
+                       continue;
+               }
+
+               /* compute hints for w1 */
+               ml_dsa_poly_mult_const_vec(k, &c, t0, h);
+               ml_dsa_poly_inv_ntt_vec(k, h);
+               ml_dsa_poly_reduce_vec(k, h);
+
+               if (!ml_dsa_poly_check_bound_vec(k, h, gamma2))
+               {
+                       continue;
+               }
+
+               ml_dsa_poly_add_vec(k, w0, h, w0);
+
+               if (ml_dsa_poly_make_hint_vec(k, w0, w1, h, gamma2) > omega)
+               {
+                       continue;
+               }
+
+               /* all checks passed - exit the loop */
+               break;
+       }
+
+       *signature = encode_signature(this, c_tilde, z, h);
+
+cleanup:
+       memwipe(a, sizeof(a));
+       memwipe(s1, sizeof(s1));
+       memwipe(s2, sizeof(s2));
+       memwipe(t0, sizeof(t0));
+       memwipe(K.ptr, K.len);
+
+       return signature->len > 0;
+}
+
+METHOD(private_key_t, decrypt, bool,
+       private_private_key_t *this, encryption_scheme_t scheme,
+       void *params, chunk_t crypto, chunk_t *plain)
+{
+       DBG1(DBG_LIB, "ML-DSA private key decryption not implemented");
+       return FALSE;
+}
+
+METHOD(private_key_t, get_keysize, int,
+       private_private_key_t *this)
+{
+       return BITS_PER_BYTE * get_public_key_size(this->type);
+}
+
+METHOD(private_key_t, get_type, key_type_t,
+       private_private_key_t *this)
+{
+       return this->type;
+}
+
+METHOD(private_key_t, get_public_key, public_key_t*,
+       private_private_key_t *this)
+{
+       return lib->creds->create(lib->creds, CRED_PUBLIC_KEY, this->type,
+                                                         BUILD_BLOB, this->pubkey, BUILD_END);
+}
+
+METHOD(private_key_t, get_fingerprint, bool,
+       private_private_key_t *this, cred_encoding_type_t type, chunk_t *fp)
+{
+       bool success;
+
+       if (lib->encoding->get_cache(lib->encoding, type, this, fp))
+       {
+               return TRUE;
+       }
+
+       success = ml_dsa_fingerprint(this->pubkey, this->type, type, fp);
+       if (success)
+       {
+               lib->encoding->cache(lib->encoding, type, this, fp);
+       }
+
+       return success;
+}
+
+METHOD(private_key_t, get_encoding, bool,
+       private_private_key_t *this, cred_encoding_type_t type, chunk_t *encoding)
+{
+       switch (type)
+       {
+               case PRIVKEY_ASN1_DER:
+               case PRIVKEY_PEM:
+               {
+                       bool success = TRUE;
+                       int oid = key_type_to_oid(this->type);
+
+                       *encoding = asn1_wrap(ASN1_SEQUENCE, "cmm",
+                                                       ASN1_INTEGER_0,
+                                                       asn1_algorithmIdentifier(oid),
+                                                       asn1_wrap(ASN1_OCTET_STRING, "m",
+                                                               asn1_simple_object(ASN1_CONTEXT_S_0,
+                                                                                                  this->keyseed))
+                                               );
+                       if (type == PRIVKEY_PEM)
+                       {
+                               chunk_t asn1_encoding = *encoding;
+
+                               success = lib->encoding->encode(lib->encoding, PRIVKEY_PEM,
+                                                               NULL, encoding, CRED_PART_PRIV_ASN1_DER,
+                                                               asn1_encoding, CRED_PART_END);
+                               chunk_clear(&asn1_encoding);
+                       }
+
+                       return success;
+               }
+               default:
+                       return FALSE;
+       }
+}
+
+METHOD(private_key_t, get_ref, private_key_t*,
+       private_private_key_t *this)
+{
+       ref_get(&this->ref);
+       return &this->public;
+}
+
+METHOD(private_key_t, destroy, void,
+       private_private_key_t *this)
+{
+       if (ref_put(&this->ref))
+       {
+               lib->encoding->clear_cache(lib->encoding, this);
+               DESTROY_IF(this->G);
+               DESTROY_IF(this->H);
+               chunk_clear(&this->keyseed);
+               chunk_clear(&this->privkey);
+               chunk_free(&this->pubkey);
+               free(this);
+       }
+}
+
+/**
+ * Generates an element of [-eta, eta] or rejects the sample.
+ *
+ * Algorithm 15 in FIPS 204.
+ */
+static bool coeff_from_half_byte(uint8_t b, uint8_t eta, int32_t *a)
+{
+       const int32_t eta_samples[] = {
+               2, 1, 0, -1, -2, 2, 1, 0, -1, -2, 2, 1, 0, -1, -2
+       };
+
+       if (eta == 2)
+       {
+               if (b >= 15)
+               {
+                       return FALSE;  /* reject sample */
+               }
+               *a = eta_samples[b];
+       }
+       else if (eta == 4)
+       {
+               if (b >= 9)
+               {
+                       return FALSE;  /* reject sample */
+               }
+               *a = 4 - (int32_t)b;
+       }
+
+       return TRUE;
+}
+
+/**
+ * Samples an element with coefficients in [-eta, eta] computed via
+ * rejection sampling on a SHAKE-256 output stream H.
+ *
+ * Algorithm 31 in FIPS 204.
+ */
+static bool rej_bounded_poly(private_private_key_t *this, chunk_t seed,
+                                                        ml_dsa_poly_t *a)
+{
+       uint8_t c, c0, c1;
+       u_int n = 0;
+
+       if (!this->H->set_seed(this->H, seed))
+       {
+               return FALSE;
+       }
+
+       while (n < ML_DSA_N)
+       {
+               if (!this->H->get_bytes(this->H, 1, &c))
+               {
+                       return FALSE;
+               }
+
+               /* form half bytes */
+               c0 = c & 0x0f;
+               c1 = c >> 4;
+               if (coeff_from_half_byte(c0, this->params->eta, &a->f[n]))
+               {
+                       if (++n == ML_DSA_N)
+                       {
+                               break;
+                       }
+               }
+               if (coeff_from_half_byte(c1, this->params->eta, &a->f[n]))
+               {
+                       ++n;
+               }
+       }
+       return TRUE;
+}
+
+/**
+ * Samples vectors s1 and s2, each with polynomial coordinates whose coefficients
+ * are in the interval [-eta, eta].
+ *
+ * Algorithm 33 in FIPS 204.
+ */
+static bool expand_s(private_private_key_t *this, chunk_t rhoprime,
+                                        ml_dsa_poly_t *s1, ml_dsa_poly_t *s2)
+{
+       chunk_t seed;
+       const u_int k = this->params->k;
+       const u_int l = this->params->l;
+       u_int i, j;
+       bool success = FALSE;
+
+       seed = chunk_alloca(2*ML_DSA_SEED_LEN + 2);
+       memcpy(seed.ptr, rhoprime.ptr, rhoprime.len);
+       seed.ptr[2*ML_DSA_SEED_LEN+1] = 0;
+
+       for (j = 0; j < l; j++)
+       {
+               seed.ptr[2*ML_DSA_SEED_LEN] = (uint8_t)j;
+           if (!rej_bounded_poly(this, seed, &s1[j]))
+           {
+                       goto cleanup;
+           }
+       }
+       for (i = 0; i < k; i++)
+       {
+               seed.ptr[2*ML_DSA_SEED_LEN] = (uint8_t)(l + i);
+           if (!rej_bounded_poly(this, seed, &s2[i]))
+           {
+                       goto cleanup;
+           }
+       }
+       success = TRUE;
+
+cleanup:
+       memwipe(seed.ptr, seed.len);
+
+       return success;
+}
+
+/**
+ * Encode the public key (pkEncode).
+ *
+ * Algorithm 22 in FIPS 204.
+ */
+static bool encode_public_key(private_private_key_t *this, chunk_t rho,
+                                                         ml_dsa_poly_t *t1)
+{
+       const u_int k = this->params->k;
+       u_int i, n;
+       ml_bitpacker_t *bitpacker;
+       chunk_t pk;
+
+       pk = this->pubkey;
+       memcpy(pk.ptr, rho.ptr, rho.len);
+       pk = chunk_skip(pk, rho.len);
+
+       /* pack the vector t1 into the public key blob */
+       bitpacker = ml_bitpacker_create(pk);
+       for (i = 0; i < k; i++)
+       {
+               for (n = 0; n < ML_DSA_N; n++)
+               {
+                       if (!bitpacker->write_bits(bitpacker, t1[i].f[n], ML_DSA_T1_BITS))
+                       {
+                               bitpacker->destroy(bitpacker);
+                               return FALSE;
+                       }
+               }
+       }
+       bitpacker->destroy(bitpacker);
+
+       return TRUE;
+}
+
+/**
+ * Encode the secret key (skEncode).
+ *
+ * Algorithm 24 in FIPS 204.
+ */
+static bool encode_secret_key(private_private_key_t *this, chunk_t rho,
+                                                         chunk_t K, ml_dsa_poly_t *s1, ml_dsa_poly_t *s2,
+                                                         ml_dsa_poly_t *t0)
+{
+       const u_int k = this->params->k;
+       const u_int l = this->params->l;
+       const u_int d = this->params->d;
+       const u_int eta = this->params->eta;
+       u_int i, j, n;
+       ml_bitpacker_t *bitpacker;
+       chunk_t sk;
+       bool success = FALSE;
+
+       sk = this->privkey;
+       memcpy(sk.ptr, rho.ptr, rho.len);
+       sk = chunk_skip(sk, rho.len);
+
+       memcpy(sk.ptr, K.ptr, K.len);
+       sk = chunk_skip(sk, K.len);
+
+       /* compute tr as a SHAKE-256 digest over the public key blob
+        * and put it in the private key blob
+        */
+       if (!this->H->set_seed(this->H, this->pubkey) ||
+               !this->H->get_bytes(this->H, ML_DSA_TR_LEN, sk.ptr))
+       {
+               return FALSE;
+       }
+       sk = chunk_skip(sk, ML_DSA_TR_LEN);
+
+       /* pack the vectors s1, s2 and t0 into the private key blob*/
+       bitpacker = ml_bitpacker_create(sk);
+       for (j = 0; j < l; j++)
+       {
+               for (n = 0; n < ML_DSA_N; n++)
+               {
+                       if (!bitpacker->write_bits(bitpacker, eta - s1[j].f[n], d))
+                       {
+                               goto end;
+                       }
+               }
+       }
+       for (i = 0; i < k; i++)
+       {
+               for (n = 0; n < ML_DSA_N; n++)
+               {
+                       if (!bitpacker->write_bits(bitpacker, eta - s2[i].f[n], d))
+                       {
+                               goto end;
+                       }
+               }
+       }
+       for (i = 0; i < k; i++)
+       {
+               for (n = 0; n < ML_DSA_N; n++)
+               {
+                       if (!bitpacker->write_bits(bitpacker, (1 << (ML_DSA_D-1)) - t0[i].f[n],
+                                                                                                                ML_DSA_D))
+                       {
+                               goto end;
+                       }
+               }
+       }
+       success = TRUE;
+
+end:
+       bitpacker->destroy(bitpacker);
+
+       return success;
+}
+
+/**
+ * Generates a public/private key pair from a seed
+ *
+ * Algorithm 6 in FIPS 204.
+ */
+static bool generate_keypair(private_private_key_t *this, chunk_t keyseed)
+{
+       const u_int k = this->params->k;
+       const u_int l = this->params->l;
+       ml_dsa_poly_t a[k*l], s1[l], s1_hat[l], s2[k], t1[k], t0[k];
+       bool success = FALSE;
+
+       /**
+        *  Mapping of seedbuf
+        *
+        *       0         32
+        *       +---------+
+        *  Init | keyseed |
+        *       +---------+
+        *       0          34
+        *       +----------+
+        *  In   |   seed   |
+        *       +----------+
+        *       0         32        64        96        128
+        *       +---------+-------------------+---------+
+        *  Out  |   rho   |      rhoprime     |    K    |
+        *       +---------+-------------------+---------+
+        */
+       uint8_t seedbuf[4*ML_DSA_SEED_LEN];
+       chunk_t seed =     { seedbuf,                       ML_DSA_SEED_LEN+2 };
+       chunk_t rho  =     { seedbuf,                       ML_DSA_SEED_LEN };
+       chunk_t rhoprime = { seedbuf +   ML_DSA_SEED_LEN, 2*ML_DSA_SEED_LEN };
+       chunk_t K =        { seedbuf + 3*ML_DSA_SEED_LEN,   ML_DSA_K_LEN };
+
+       /* keep a copy of the secret key seed */
+       this->keyseed = keyseed;
+
+       memcpy(seedbuf, keyseed.ptr, keyseed.len);
+       seedbuf[ML_DSA_SEED_LEN]   = this->params->k;
+       seedbuf[ML_DSA_SEED_LEN+1] = this->params->l;
+
+       if (!this->H->set_seed(this->H, seed) ||
+               !this->H->get_bytes(this->H, sizeof(seedbuf), seedbuf) ||
+               !ml_dsa_expand_a(this->params, this->G, rho, a) ||
+               !expand_s(this, rhoprime, s1, s2))
+       {
+               goto cleanup;
+       }
+
+       /* apply NTT to a copy of the s1 vector */
+       ml_dsa_poly_copy_vec(l, s1, s1_hat);
+       ml_dsa_poly_ntt_vec(l, s1_hat);
+
+       /* multiply vector s1_hat with matrix a in the NTT domain */
+       ml_dsa_poly_mult_mat(k, l, a, s1_hat, t1);
+
+       /* reduce the elements of vector t1 to the range -6283008 <= r <= 6283008 */
+       ml_dsa_poly_reduce_vec(k, t1);
+
+       /* apply the inverse NTT to vector t1 */
+       ml_dsa_poly_inv_ntt_vec(k, t1);
+
+       /* add error vector s2 to t1 */
+       ml_dsa_poly_add_vec(k, s2, t1, t1);
+
+       /* make all polynomial coefficients positive by conditionally adding q */
+       ml_dsa_poly_cond_add_q_vec(k, t1);
+
+       /* decomposes t1 into (t1, t0) such that t1 ≡ t1 * 2^d + t0 mod q */
+       ml_dsa_poly_power2round_vec(k, t1, t0, t1);
+
+       success = encode_public_key(this, rho, t1) &&
+                         encode_secret_key(this, rho, K, s1, s2, t0);
+
+cleanup:
+       memwipe(seedbuf, sizeof(seedbuf));
+       memwipe(a, sizeof(a));
+       memwipe(s1, sizeof(s1));
+       memwipe(s1_hat, sizeof(s1_hat));
+       memwipe(s2, sizeof(s2));
+       memwipe(t0, sizeof(t0));
+
+       return success;
+}
+
+/**
+ * Generic private constructor
+ */
+static private_private_key_t *create_instance(key_type_t type)
+{
+       private_private_key_t *this;
+       const ml_dsa_params_t *params;
+
+       params = ml_dsa_params_get(type);
+       if (!params)
+       {
+               return NULL;
+       }
+
+       INIT(this,
+               .public = {
+                       .get_type = _get_type,
+                       .sign = _sign,
+                       .decrypt = _decrypt,
+                       .get_keysize = _get_keysize,
+                       .get_public_key = _get_public_key,
+                       .equals = private_key_equals,
+                       .belongs_to = private_key_belongs_to,
+                       .get_fingerprint = _get_fingerprint,
+                       .has_fingerprint = private_key_has_fingerprint,
+                       .get_encoding = _get_encoding,
+                       .get_ref = _get_ref,
+                       .destroy = _destroy,
+               },
+               .type = type,
+               .params = params,
+               .pubkey = chunk_alloc(get_public_key_size(type)),
+               .privkey = chunk_alloc(params->privkey_len),
+               .G = lib->crypto->create_xof(lib->crypto, XOF_SHAKE_128),
+               .H = lib->crypto->create_xof(lib->crypto, XOF_SHAKE_256),
+               .ref = 1,
+       );
+
+       if (!this->G || !this->H)
+       {
+               destroy(this);
+               return NULL;
+       }
+
+       return this;
+}
+
+/*
+ * Described in header
+ */
+private_key_t *ml_dsa_private_key_gen(key_type_t type, va_list args)
+{
+       private_private_key_t *this;
+       chunk_t seed;
+       rng_t *rng;
+
+       while (TRUE)
+       {
+               switch (va_arg(args, builder_part_t))
+               {
+                       case BUILD_KEY_SIZE:
+                               /* just ignore the key size */
+                               va_arg(args, u_int);
+                               continue;
+                       case BUILD_END:
+                               break;
+                       default:
+                               return NULL;
+               }
+               break;
+       }
+
+       rng = lib->crypto->create_rng(lib->crypto, RNG_TRUE);
+       if (!rng || !rng->allocate_bytes(rng, ML_DSA_SEED_LEN, &seed))
+       {
+               DESTROY_IF(rng);
+               return NULL;
+       }
+       rng->destroy(rng);
+
+       this = create_instance(type);
+       if (!this)
+       {
+               chunk_free(&seed);
+               return NULL;
+       }
+
+       if (!generate_keypair(this, seed))
+       {
+               destroy(this);
+               return NULL;
+       }
+       return &this->public;
+}
+
+/*
+ * Described in header
+ */
+private_key_t *ml_dsa_private_key_load(key_type_t type, va_list args)
+{
+       private_private_key_t *this;
+       chunk_t priv = chunk_empty;
+
+       while (TRUE)
+       {
+               switch (va_arg(args, builder_part_t))
+               {
+                       case BUILD_BLOB:
+                               priv = va_arg(args, chunk_t);
+                               continue;
+                       case BUILD_END:
+                               break;
+                       default:
+                               return NULL;
+               }
+               break;
+       }
+
+       if (priv.len == 0 || !ml_dsa_type_supported(type))
+       {
+               return NULL;
+       }
+       if (priv.len == ML_DSA_SEED_LEN + 2 &&
+               priv.ptr[0] == 0x80 && priv.ptr[1] == ML_DSA_SEED_LEN)
+       {
+               priv = chunk_skip(priv, 2);
+       }
+       if (priv.len != ML_DSA_SEED_LEN)
+       {
+               DBG1(DBG_LIB, "error: the size of the loaded ML-DSA private key seed is "
+                                         "%u bytes instead of %d bytes", priv.len, ML_DSA_SEED_LEN);
+               return NULL;
+       }
+
+       this = create_instance(type);
+       if (!this)
+       {
+               return NULL;
+       }
+
+       if (!generate_keypair(this, chunk_clone(priv)))
+       {
+               destroy(this);
+               return NULL;
+       }
+
+       return &this->public;
+}
diff --git a/src/libstrongswan/plugins/ml/ml_dsa_private_key.h b/src/libstrongswan/plugins/ml/ml_dsa_private_key.h
new file mode 100644 (file)
index 0000000..473cecd
--- /dev/null
@@ -0,0 +1,48 @@
+/*
+ * Copyright (C) 2024 Andreas Steffen
+ *
+ * Copyright (C) secunet Security Networks AG
+ *
+ * This program is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License as published by the
+ * Free Software Foundation; either version 2 of the License, or (at your
+ * option) any later version.  See <http://www.fsf.org/copyleft/gpl.txt>.
+ *
+ * This program is distributed in the hope that it will be useful, but
+ * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
+ * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * for more details.
+ */
+
+/**
+ * @defgroup ml_dsa_private_key ml_dsa_private_key
+ * @{ @ingroup ml_p
+ */
+
+#ifndef ML_DSA_PRIVATE_KEY_H_
+#define ML_DSA_PRIVATE_KEY_H_
+
+#include <credentials/builder.h>
+#include <credentials/keys/private_key.h>
+
+/**
+ * Generate an ML-DSA private key.
+ *
+ * @param type         key type, must be KEY_ML_DSA_44, KEY_ML_DSA_65 or KEY_ML_DSA_87
+ * @param args         builder_part_t argument list
+ * @return                     generated key, NULL on failure
+ */
+private_key_t *ml_dsa_private_key_gen(key_type_t type, va_list args);
+
+/**
+ * Load an ML-DSA private key using.
+ *
+ * Accepts a BUILD_BLOB argument.
+ *
+ * @param type         key type, must be KEY_ML_DSA_44, KEY_ML_DSA_65 or KEY_ML_DSA_87
+ * @param args         builder_part_t argument list
+ * @return                     loaded key, NULL on failure
+ */
+private_key_t *ml_dsa_private_key_load(key_type_t type, va_list args);
+
+#endif /** ML_DSA_PRIVATE_KEY_H_ @}*/
diff --git a/src/libstrongswan/plugins/ml/ml_dsa_public_key.c b/src/libstrongswan/plugins/ml/ml_dsa_public_key.c
new file mode 100644 (file)
index 0000000..c2bef3e
--- /dev/null
@@ -0,0 +1,654 @@
+/*
+ * Copyright (C) 2024 Andreas Steffen, strongSec GmbH
+ *
+ * Copyright (C) secunet Security Networks AG
+ *
+ * This program is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License as published by the
+ * Free Software Foundation; either version 2 of the License, or (at your
+ * option) any later version.  See <http://www.fsf.org/copyleft/gpl.txt>.
+ *
+ * This program is distributed in the hope that it will be useful, but
+ * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
+ * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * for more details.
+ */
+
+#include "ml_dsa_public_key.h"
+#include "ml_dsa_params.h"
+#include "ml_dsa_poly.h"
+#include "ml_bitpacker.h"
+
+#include <utils/debug.h>
+#include <asn1/asn1.h>
+
+typedef struct private_public_key_t private_public_key_t;
+
+/**
+ * Private data
+ */
+struct private_public_key_t {
+
+       /**
+        * Public interface
+        */
+       public_key_t public;
+
+       /**
+        * Key type
+        */
+       key_type_t type;
+
+       /**
+        * Parameter set.
+        */
+       const ml_dsa_params_t *params;
+
+       /**
+        * Public key
+        */
+       chunk_t pubkey;
+
+       /**
+        * SHAKE-128 instance.
+        */
+       xof_t *G;
+
+       /**
+        * SHAKE-256 instance.
+        */
+       xof_t *H;
+
+       /**
+        * Reference count
+        */
+       refcount_t ref;
+};
+
+METHOD(public_key_t, get_type, key_type_t,
+       private_public_key_t *this)
+{
+       return this->type;
+}
+
+/**
+ * Decode the public key (pkDncode).
+ *
+ * Algorithm 23 in FIPS 204.
+ */
+static bool decode_public_key(private_public_key_t *this, chunk_t rho,
+                                                         ml_dsa_poly_t *t1)
+{
+       const u_int k = this->params->k;
+       u_int i, n;
+       ml_bitpacker_t *bitpacker;
+       chunk_t pk;
+
+       pk = this->pubkey;
+       memcpy(rho.ptr, pk.ptr, rho.len);
+       pk = chunk_skip(pk, rho.len);
+
+       /* unpack the vector t1 from the public key blob */
+       bitpacker = ml_bitpacker_create_from_data(pk);
+       for (i = 0; i < k; i++)
+       {
+               for (n = 0; n < ML_DSA_N; n++)
+               {
+                       if (!bitpacker->read_bits(bitpacker, &t1[i].f[n], ML_DSA_T1_BITS))
+                       {
+                               bitpacker->destroy(bitpacker);
+                               return FALSE;
+                       }
+               }
+       }
+       bitpacker->destroy(bitpacker);
+
+       return TRUE;
+}
+
+/**
+ * Decodes a signature (sigDecode).
+ *
+ * Algorithm 27 in FIPS 204.
+ */
+static bool decode_signature(private_public_key_t *this, chunk_t signature,
+                                                        chunk_t c_tilde, ml_dsa_poly_t *z, ml_dsa_poly_t *h)
+{
+       const u_int k = this->params->k;
+       const u_int l = this->params->l;
+       const u_int gamma1_exp = this->params->gamma1_exp;
+       const u_int omega = this->params->omega;
+       const size_t sig_len = this->params->sig_len;
+       ml_bitpacker_t *bitpacker;
+       u_int i, j, n, first, index = 0;
+       uint32_t value;
+       chunk_t sig;
+
+       if (signature.len != sig_len)
+       {
+               DBG1(DBG_LIB, "error: the size of the ML-DSA signature is %u bytes "
+                                         "instead of %u bytes", signature.len, sig_len);
+               return FALSE;
+       }
+       sig = signature;
+
+       /* extract byte string c_tilde */
+       memcpy(c_tilde.ptr, sig.ptr, c_tilde.len);
+       sig = chunk_skip(sig, c_tilde.len);
+
+       /* decode vector of polynomials z from packed format */
+       bitpacker = ml_bitpacker_create_from_data(sig);
+       for (j = 0; j < l; j++)
+       {
+               for (n = 0; n < ML_DSA_N; n++)
+               {
+                       if (!bitpacker->read_bits(bitpacker, &value, 1 +  gamma1_exp))
+                       {
+                               bitpacker->destroy(bitpacker);
+                               chunk_free(&signature);
+                               return FALSE;
+                       }
+                       z[j].f[n] = (1 << gamma1_exp) - (int32_t)value;
+               }
+       }
+       bitpacker->destroy(bitpacker);
+       sig = chunk_skip(sig, 32 * (1 + gamma1_exp) * l);
+
+       /* decode vector of polynomials with binary coefficients h (HintBitUnpack)
+        * Algorithm 21 in FIPS 204.
+        */
+       memset(h, 0x00, k * ML_DSA_N * sizeof(int32_t));
+
+       for (i = 0; i < k; i++)
+       {
+               if (sig.ptr[omega + i] < index)
+               {
+                       DBG1(DBG_LIB, "error: signature with a decreasing maximum hint index");
+                       return FALSE;
+               }
+               if (sig.ptr[omega + i] > omega)
+               {
+                       DBG1(DBG_LIB, "error: signature with an oversized maximum hint index");
+                       return FALSE;
+               }
+
+               first = index;
+               while (index < sig.ptr[omega + i])
+               {
+                       if (index > first && sig.ptr[index-1] >= sig.ptr[index])
+                       {
+                               DBG1(DBG_LIB, "error: signature with non-increasing hint positions");
+                               return FALSE;
+                       }
+                       h[i].f[sig.ptr[index++]] = 1;
+               }
+       }
+       while (index < omega)
+       {
+               if (sig.ptr[index++] != 0x00)
+               {
+                       DBG1(DBG_LIB, "error: signature with non-zeroed unused hint bit");
+                       return FALSE;
+               }
+       }
+
+       return TRUE;
+}
+
+/**
+ * Samples a polynomial with uniformly random coefficients in [0,Q-1]
+ * by performing rejection sampling on a SHAKE-128 output stream G.
+ *
+ * Algorithm 30 in FIPS 204.
+ */
+static bool rej_ntt_poly(xof_t *G, chunk_t seed, ml_dsa_poly_t *a)
+{
+       uint8_t c[3];
+       uint32_t t;
+       u_int n = 0;
+
+       if (!G->set_seed(G, seed))
+       {
+               return FALSE;
+       }
+
+       while (n < ML_DSA_N)
+       {
+               if (!G->get_bytes(G, sizeof(c), c))
+               {
+                       return FALSE;
+               }
+
+               /* Algorithm 14 in FIPS 204 (CoeffFromThreeBytes) */
+               t = (uint32_t)c[2] << 16 | (uint32_t)c[1] << 8 | (uint32_t)c[0];
+               t &= 0x7fffff;
+
+               if (t < ML_DSA_Q)
+               {
+                       a->f[n++] = t;
+               }
+       }
+
+       return TRUE;
+}
+
+/**
+ * Samples a k x l matrix A.
+ *
+ * Algorithm 32 in FIPS 204.
+ */
+bool ml_dsa_expand_a(const ml_dsa_params_t *params, xof_t *G, chunk_t rho,
+                                        ml_dsa_poly_t *a)
+{
+       const u_int k = params->k;
+       const u_int l = params->l;
+       u_int i, j, ctr = 0;
+       chunk_t seed;
+
+       seed = chunk_alloca(ML_DSA_SEED_LEN + 2);
+       memcpy(seed.ptr, rho.ptr, rho.len);
+
+       for (i = 0; i < k; i++)
+       {
+               for (j = 0; j < l; j++)
+               {
+                       seed.ptr[ML_DSA_SEED_LEN+1] = (uint8_t)i;
+                       seed.ptr[ML_DSA_SEED_LEN]   = (uint8_t)j;
+
+                       if (!rej_ntt_poly(G, seed, &a[ctr++]))
+                       {
+                               return FALSE;
+                       }
+               }
+       }
+
+       return TRUE;
+}
+
+/**
+ * samples a polynomial c with coefficients from {-1, 0, 1}
+ * and Hamming weight tau <= 64.
+ *
+ * Algorithm 29 in FIPS 204.
+ */
+bool ml_dsa_sample_in_ball(xof_t *H, int32_t tau, chunk_t rho, ml_dsa_poly_t *c)
+{
+       uint8_t s[8], b;
+       uint64_t signs = 0;
+       u_int i;
+
+       if (!H->set_seed(H, rho) ||
+               !H->get_bytes(H, 8, s))
+       {
+               return FALSE;
+       }
+       for (i = 0; i < 8; i++)
+       {
+               signs |= (uint64_t)s[i] << 8*i;
+       }
+       for (i = 0; i < ML_DSA_N; i++)
+       {
+               c->f[i] = 0;
+       }
+       for (i = ML_DSA_N - tau; i < ML_DSA_N; i++)
+       {
+               do
+               {
+                       if (!H->get_bytes(H, 1, &b))
+                       {
+                               return FALSE;
+                       }
+               } while (b > i);
+
+               c->f[i] = c->f[b];
+               c->f[b] = 1 - 2*(signs & 1);
+               signs >>= 1;
+       }
+
+       return TRUE;
+}
+
+/**
+ * Encodes a polynomial vector w1 into a byte string.
+ *
+ * Algorithm 28 in FIPS 204.
+ */
+bool ml_dsa_w1_encode(u_int k, ml_dsa_poly_t *w1, chunk_t w1_enc, u_int d)
+{
+       ml_bitpacker_t *bitpacker;
+       u_int j, n;
+
+       bitpacker = ml_bitpacker_create(w1_enc);
+       for (j = 0; j < k; j++)
+       {
+               for (n = 0; n < ML_DSA_N; n++)
+               {
+                       if (!bitpacker->write_bits(bitpacker, w1[j].f[n], d))
+                       {
+                               bitpacker->destroy(bitpacker);
+                               return FALSE;
+                       }
+               }
+       }
+       bitpacker->destroy(bitpacker);
+
+       return TRUE;
+}
+
+METHOD(public_key_t, verify, bool,
+       private_public_key_t *this, signature_scheme_t scheme,
+       void *params, chunk_t data, chunk_t signature)
+{
+       const u_int k = this->params->k;
+       const u_int l = this->params->l;
+       const int32_t bound = (1 << this->params->gamma1_exp) - this->params->beta;
+       pqc_params_t pqc_params;
+       ml_dsa_poly_t a[k*l], t1[k], z[l], h[k], w1[k], c;
+       chunk_t rho, c_tilde, c_tilde2, w1_enc, tr, seed, mu;
+
+       if (key_type_from_signature_scheme(scheme) != this->type)
+       {
+               DBG1(DBG_LIB, "signature scheme %N not supported",
+                                          signature_scheme_names, scheme);
+               return FALSE;
+       }
+
+       rho      = chunk_alloca(ML_DSA_SEED_LEN);
+       c_tilde  = chunk_alloca(this->params->lambda / 4);
+       c_tilde2 = chunk_alloca(this->params->lambda / 4);
+       w1_enc   = chunk_alloca(32 * this->params->gamma2_d * k);
+       tr       = chunk_alloca(ML_DSA_TR_LEN);
+       mu       = chunk_alloca(ML_DSA_MU_LEN);
+
+       if (!decode_public_key(this, rho, t1) ||
+               !decode_signature(this, signature, c_tilde, z, h) ||
+               !ml_dsa_poly_check_bound_vec(l, z, bound) ||
+               !ml_dsa_expand_a(this->params, this->G, rho, a))
+       {
+               return FALSE;
+       }
+
+       /* compute tr as a SHAKE-256 digest over the public key blob */
+       if (!this->H->set_seed(this->H, this->pubkey) ||
+               !this->H->get_bytes(this->H, ML_DSA_TR_LEN, tr.ptr))
+       {
+               return FALSE;
+       }
+
+       /* set PQC signature params */
+       if (!pqc_params_create(params, &pqc_params))
+       {
+               return FALSE;
+       }
+
+       /* compute message representative mu */
+       seed = chunk_cat("cmc", tr, pqc_params.pre_ctx, data);
+       if (!this->H->set_seed(this->H, seed) ||
+               !this->H->get_bytes(this->H, ML_DSA_MU_LEN, mu.ptr))
+       {
+               chunk_free(&seed);
+               return FALSE;
+       }
+       chunk_free(&seed);
+
+       /* verifier's challenge */
+       if (!ml_dsa_sample_in_ball(this->H, this->params->tau, c_tilde, &c))
+       {
+               return FALSE;
+       }
+
+       /* compute w1 = a * z - c * 2^d * t1 */
+       ml_dsa_poly_ntt(&c);
+       ml_dsa_poly_ntt_vec(l, z);
+       ml_dsa_poly_mult_mat(k, l, a, z, w1);
+       ml_dsa_poly_shift_left_vec(k, t1);
+       ml_dsa_poly_ntt_vec(k, t1);
+       ml_dsa_poly_mult_const_vec(k, &c, t1, t1);
+       ml_dsa_poly_sub_vec(k, w1, t1, w1);
+       ml_dsa_poly_reduce_vec(k, w1);
+       ml_dsa_poly_inv_ntt_vec(k, w1);
+
+       /* reconstruct w1 */
+       ml_dsa_poly_cond_add_q_vec(k, w1);
+       ml_dsa_poly_use_hint_vec(k, w1, h, w1, this->params->gamma2);
+
+       /* compress the w1 vector into a byte string */
+       if (!ml_dsa_w1_encode(k, w1, w1_enc, this->params->gamma2_d))
+       {
+               return FALSE;
+       }
+
+       /* compute commitment hash c_tilde2 */
+       seed = chunk_cat("cc", mu, w1_enc);
+       if (!this->H->set_seed(this->H, seed) ||
+               !this->H->get_bytes(this->H, this->params->lambda/4, c_tilde2.ptr))
+       {
+               chunk_free(&seed);
+               return FALSE;
+       }
+       chunk_free(&seed);
+
+       return chunk_equals_const(c_tilde2, c_tilde);
+}
+
+METHOD(public_key_t, encrypt_, bool,
+       private_public_key_t *this, encryption_scheme_t scheme,
+       void *params, chunk_t crypto, chunk_t *plain)
+{
+       DBG1(DBG_LIB, "encryption scheme %N not supported", encryption_scheme_names,
+                scheme);
+       return FALSE;
+}
+
+METHOD(public_key_t, get_keysize, int,
+       private_public_key_t *this)
+{
+       return BITS_PER_BYTE * get_public_key_size(this->type);
+}
+
+/**
+ * Generate two types of ML-DSA fingerprints.
+ */
+bool ml_dsa_fingerprint(chunk_t pubkey, key_type_t type,
+                                               cred_encoding_type_t enc_type, chunk_t *fp)
+{
+       chunk_t encoding;
+       hasher_t *hasher;
+
+       *fp = chunk_empty;
+
+       switch (enc_type)
+       {
+               case KEYID_PUBKEY_SHA1:
+                       encoding = chunk_clone(pubkey);
+                       break;
+               case KEYID_PUBKEY_INFO_SHA1:
+                       encoding = public_key_info_encode(pubkey, key_type_to_oid(type));
+                       break;
+               default:
+                       return FALSE;
+       }
+
+       hasher = lib->crypto->create_hasher(lib->crypto, HASH_SHA1);
+       if (!hasher || !hasher->allocate_hash(hasher, encoding, fp))
+       {
+               DBG1(DBG_LIB, "SHA1 hash algorithm not supported");
+               DESTROY_IF(hasher);
+               chunk_free(&encoding);
+               return FALSE;
+       }
+       hasher->destroy(hasher);
+       chunk_free(&encoding);
+
+       return TRUE;
+}
+
+METHOD(public_key_t, get_fingerprint, bool,
+       private_public_key_t *this, cred_encoding_type_t type, chunk_t *fp)
+{
+       bool success;
+
+       if (lib->encoding->get_cache(lib->encoding, type, this, fp))
+       {
+               return TRUE;
+       }
+
+       success = ml_dsa_fingerprint(this->pubkey, this->type, type, fp);
+       if (success)
+       {
+               lib->encoding->cache(lib->encoding, type, this, fp);
+       }
+
+       return success;
+}
+
+METHOD(public_key_t, get_encoding, bool,
+       private_public_key_t *this, cred_encoding_type_t type, chunk_t *encoding)
+{
+       bool success = TRUE;
+       int oid;
+
+       oid = key_type_to_oid(this->type);
+       *encoding = public_key_info_encode(this->pubkey, oid);
+
+       if (type != PUBKEY_SPKI_ASN1_DER)
+       {
+               chunk_t asn1_encoding = *encoding;
+
+               success = lib->encoding->encode(lib->encoding, type,
+                                               NULL, encoding, CRED_PART_PUB_ASN1_DER,
+                                               asn1_encoding, CRED_PART_END);
+               chunk_clear(&asn1_encoding);
+       }
+
+       return success;
+}
+
+METHOD(public_key_t, get_ref, public_key_t*,
+       private_public_key_t *this)
+{
+       ref_get(&this->ref);
+       return &this->public;
+}
+
+METHOD(public_key_t, destroy, void,
+       private_public_key_t *this)
+{
+       if (ref_put(&this->ref))
+       {
+               lib->encoding->clear_cache(lib->encoding, this);
+               DESTROY_IF(this->G);
+               DESTROY_IF(this->H);
+               chunk_free(&this->pubkey);
+               free(this);
+       }
+}
+
+/**
+ * Generic private constructor
+ */
+static private_public_key_t *create_empty(key_type_t type, chunk_t pubkey)
+{
+       private_public_key_t *this;
+       const ml_dsa_params_t *params;
+
+       params = ml_dsa_params_get(type);
+       if (!params)
+       {
+               return NULL;
+       }
+
+       INIT(this,
+               .public = {
+                       .get_type = _get_type,
+                       .verify = _verify,
+                       .encrypt = _encrypt_,
+                       .get_keysize = _get_keysize,
+                       .equals = public_key_equals,
+                       .get_fingerprint = _get_fingerprint,
+                       .has_fingerprint = public_key_has_fingerprint,
+                       .get_encoding = _get_encoding,
+                       .get_ref = _get_ref,
+                       .destroy = _destroy,
+               },
+               .type = type,
+               .params = params,
+               .pubkey = chunk_clone(pubkey),
+               .G = lib->crypto->create_xof(lib->crypto, XOF_SHAKE_128),
+               .H = lib->crypto->create_xof(lib->crypto, XOF_SHAKE_256),
+               .ref = 1,
+       );
+
+       if (!this->G || !this->H)
+       {
+               destroy(this);
+               return NULL;
+       }
+
+       return this;
+}
+
+/**
+ * Check if ML-DSA key type is supported.
+ */
+bool ml_dsa_type_supported(key_type_t type)
+{
+       switch (type)
+       {
+               case KEY_ML_DSA_44:
+               case KEY_ML_DSA_65:
+               case KEY_ML_DSA_87:
+                       return TRUE;
+               default:
+                       return FALSE;
+       }
+}
+
+/*
+ * Described in header
+ */
+public_key_t *ml_dsa_public_key_load(key_type_t type, va_list args)
+{
+       private_public_key_t *this;
+       chunk_t pkcs1, blob = chunk_empty;
+       size_t pubkey_len;
+
+       while (TRUE)
+       {
+               switch (va_arg(args, builder_part_t))
+               {
+                       case BUILD_BLOB:
+                               blob = va_arg(args, chunk_t);
+                               continue;
+                       case BUILD_BLOB_ASN1_DER:
+                               pkcs1 = va_arg(args, chunk_t);
+                               type = public_key_info_decode(pkcs1, &blob);
+                               continue;
+                       case BUILD_END:
+                               break;
+                       default:
+                               return NULL;
+               }
+               break;
+       }
+
+       if (!ml_dsa_type_supported(type) || blob.len == 0)
+       {
+               return NULL;
+       }
+       pubkey_len = get_public_key_size(type);
+       if (blob.len != pubkey_len)
+       {
+               DBG1(DBG_LIB, "the size of the loaded ML-DSA public key is %u bytes "
+                                         "instead of %u bytes", blob.len, pubkey_len);
+               return NULL;
+       }
+
+       this = create_empty(type, blob);
+       if (!this)
+       {
+               return NULL;
+       }
+
+       return &this->public;
+}
diff --git a/src/libstrongswan/plugins/ml/ml_dsa_public_key.h b/src/libstrongswan/plugins/ml/ml_dsa_public_key.h
new file mode 100644 (file)
index 0000000..c411c95
--- /dev/null
@@ -0,0 +1,39 @@
+/*
+ * Copyright (C) 2024 Andreas Steffen, strongSec GmbH
+ *
+ * Copyright (C) secunet Security Networks AG
+ *
+ * This program is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License as published by the
+ * Free Software Foundation; either version 2 of the License, or (at your
+ * option) any later version.  See <http://www.fsf.org/copyleft/gpl.txt>.
+ *
+ * This program is distributed in the hope that it will be useful, but
+ * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
+ * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * for more details.
+ */
+
+/**
+ * @defgroup wolfssl_ml_dsa_public_key wolfssl_ml_dsa_public_key
+ * @{ @ingroup wolfssl_p
+ */
+
+#ifndef ML_DSA_PUBLIC_KEY_H_
+#define ML_DSA_PUBLIC_KEY_H_
+
+#include <credentials/builder.h>
+#include <credentials/keys/public_key.h>
+
+/**
+ * Load an ML-DSA public key.
+ *
+ * Accepts a BUILD_BLOB or BUILD_BLOB_ASN1_DER argument.
+ *
+ * @param type         key type, must be KEY_ML_DSA_44, KEY_ML_DSA_65 or KEY_ML_DSA_87
+ * @param args         builder_part_t argument list
+ * @return                     loaded key, NULL on failure
+ */
+public_key_t *ml_dsa_public_key_load(key_type_t type, va_list args);
+
+#endif /** ML_DSA_PUBLIC_KEY_H_ @}*/
index 9da72ea6114cd330620411f87eb94d38ac0274cf..da3b9672870b6806d886d9fbd0e5db66dd13f8e9 100644 (file)
@@ -16,7 +16,7 @@
 
 #include "ml_bitpacker.h"
 #include "ml_kem.h"
-#include "ml_params.h"
+#include "ml_kem_params.h"
 #include "ml_poly.h"
 #include "ml_utils.h"
 
similarity index 98%
rename from src/libstrongswan/plugins/ml/ml_params.c
rename to src/libstrongswan/plugins/ml/ml_kem_params.c
index 89b1e6dbe56c0239fc9e28304352afd557f778f8..d926ea4eaee001ddd11b26405d86783ea8ba3576 100644 (file)
@@ -14,7 +14,7 @@
  * for more details.
  */
 
-#include "ml_params.h"
+#include "ml_kem_params.h"
 
 /*
  * Described in header
similarity index 94%
rename from src/libstrongswan/plugins/ml/ml_params.h
rename to src/libstrongswan/plugins/ml/ml_kem_params.h
index 280aed6155a6bc0be0ca823c240d6022c215beba..bc15c5b73bf1a567397daa9422707fc8637d9ee1 100644 (file)
  */
 
 /**
- * @defgroup ml_params ml_params
+ * @defgroup ml_kem_params ml_kem_params
  * @{ @ingroup ml_p
  */
 
-#ifndef ML_PARAMS_H_
-#define ML_PARAMS_H_
+#ifndef ML_KEM_PARAMS_H_
+#define ML_KEM_PARAMS_H_
 
 #include <crypto/key_exchange.h>
 
@@ -107,4 +107,4 @@ extern const uint16_t ml_kem_zetas[128];
  */
 const ml_kem_params_t *ml_kem_params_get(key_exchange_method_t method);
 
-#endif /** ML_PARAMS_H_ @}*/
+#endif /** ML_KEM_PARAMS_H_ @}*/
index 752c20da7c726bc52988a85c601026afc20fee5f..f50b786cb7d008f3c522c22679a578d7f59be518 100644 (file)
@@ -19,6 +19,8 @@
 #include <plugins/plugin.h>
 
 #include "ml_kem.h"
+#include "ml_dsa_public_key.h"
+#include "ml_dsa_private_key.h"
 
 typedef struct private_plugin_t private_plugin_t;
 
@@ -39,6 +41,25 @@ METHOD(plugin_t, get_name, char*,
        return "ml";
 }
 
+/**
+ * Helper macros to declare dependencies for ML-DSA.
+ */
+#define ML_DSA_DEPS \
+       PLUGIN_DEPENDS(XOF, XOF_SHAKE_128), \
+       PLUGIN_DEPENDS(XOF, XOF_SHAKE_256)
+
+#define ML_DSA_PUBKEY_DEPS \
+       ML_DSA_DEPS, \
+       PLUGIN_DEPENDS(HASHER, HASH_SHA1)
+
+#define ML_DSA_PRIVKEY_DEPS \
+       ML_DSA_DEPS, \
+       PLUGIN_DEPENDS(RNG, RNG_STRONG)
+
+#define ML_DSA_PRIVKEY_GEN_DEPS \
+       ML_DSA_PRIVKEY_DEPS, \
+       PLUGIN_DEPENDS(RNG, RNG_TRUE)
+
 METHOD(plugin_t, get_features, int,
        private_plugin_t *this, plugin_feature_t *features[])
 {
@@ -62,6 +83,41 @@ METHOD(plugin_t, get_features, int,
                                PLUGIN_DEPENDS(XOF, XOF_SHAKE_128),
                                PLUGIN_DEPENDS(XOF, XOF_SHAKE_256),
                                PLUGIN_DEPENDS(RNG, RNG_STRONG),
+               PLUGIN_REGISTER(PUBKEY, ml_dsa_public_key_load, TRUE),
+                       PLUGIN_PROVIDE(PUBKEY, KEY_ML_DSA_44),
+                               ML_DSA_PUBKEY_DEPS,
+                       PLUGIN_PROVIDE(PUBKEY, KEY_ML_DSA_65),
+                               ML_DSA_PUBKEY_DEPS,
+                       PLUGIN_PROVIDE(PUBKEY, KEY_ML_DSA_87),
+                               ML_DSA_PUBKEY_DEPS,
+                       PLUGIN_PROVIDE(PUBKEY, KEY_ANY),
+                               ML_DSA_PUBKEY_DEPS,
+               PLUGIN_REGISTER(PRIVKEY, ml_dsa_private_key_load, TRUE),
+                       PLUGIN_PROVIDE(PRIVKEY, KEY_ML_DSA_44),
+                               ML_DSA_PRIVKEY_DEPS,
+                       PLUGIN_PROVIDE(PRIVKEY, KEY_ML_DSA_65),
+                               ML_DSA_PRIVKEY_DEPS,
+                       PLUGIN_PROVIDE(PRIVKEY, KEY_ML_DSA_87),
+                               ML_DSA_PRIVKEY_DEPS,
+               PLUGIN_REGISTER(PRIVKEY_GEN, ml_dsa_private_key_gen, FALSE),
+                       PLUGIN_PROVIDE(PRIVKEY_GEN, KEY_ML_DSA_44),
+                               ML_DSA_PRIVKEY_GEN_DEPS,
+                       PLUGIN_PROVIDE(PRIVKEY_GEN, KEY_ML_DSA_65),
+                               ML_DSA_PRIVKEY_GEN_DEPS,
+                       PLUGIN_PROVIDE(PRIVKEY_GEN, KEY_ML_DSA_87),
+                               ML_DSA_PRIVKEY_GEN_DEPS,
+               PLUGIN_PROVIDE(PRIVKEY_SIGN, SIGN_ML_DSA_44),
+                       ML_DSA_PRIVKEY_DEPS,
+               PLUGIN_PROVIDE(PRIVKEY_SIGN, SIGN_ML_DSA_65),
+                       ML_DSA_PRIVKEY_DEPS,
+               PLUGIN_PROVIDE(PRIVKEY_SIGN, SIGN_ML_DSA_87),
+                       ML_DSA_PRIVKEY_DEPS,
+               PLUGIN_PROVIDE(PUBKEY_VERIFY, SIGN_ML_DSA_44),
+                       ML_DSA_PUBKEY_DEPS,
+               PLUGIN_PROVIDE(PUBKEY_VERIFY, SIGN_ML_DSA_65),
+                       ML_DSA_PUBKEY_DEPS,
+               PLUGIN_PROVIDE(PUBKEY_VERIFY, SIGN_ML_DSA_87),
+                       ML_DSA_PUBKEY_DEPS,
        };
        *features = f;
        return countof(f);
index 3863598edb2e1963e9a9d1282fea670f0c3e7666..510282c83d45325734ec2d380769abc8853ef476 100644 (file)
@@ -22,7 +22,7 @@
 #ifndef ML_POLY_H_
 #define ML_POLY_H_
 
-#include "ml_params.h"
+#include "ml_kem_params.h"
 
 typedef struct ml_poly_t ml_poly_t;
 
index 7b92f53aef54f178707d4dffcb5dad496aae4f33..cc74c3900e2b9e53fc27b2b1c1102ce66d7b8749 100644 (file)
@@ -53,3 +53,75 @@ void ml_write_bytes_le(uint8_t *buf, size_t len, uint32_t val)
                val >>= 8;
        }
 }
+
+/*
+ * Described in header
+ */
+void ml_decompose(int32_t a, int32_t *a0, int32_t *a1, int32_t gamma2)
+{
+       int32_t t0, t1;
+
+       t1  = (a + 127) >> 7;
+
+       if (gamma2 == (ML_DSA_Q-1)/32)
+       {
+               t1  = (t1 * 1025 + (1 << 21)) >> 22;
+               t1 &= 15;
+       }
+       else
+       {
+               t1  = (t1 * 11275 + (1 << 23)) >> 24;
+               t1 ^= ((43 - t1) >> 31) & t1;
+       }
+
+       t0 = a - t1 * 2 * gamma2;
+       t0 -= (((ML_DSA_Q-1)/2 - t0) >> 31) & ML_DSA_Q;
+
+       *a0 = t0;
+       *a1 = t1;
+}
+
+/*
+ * Described in header
+ */
+int32_t ml_use_hint(int32_t a, int32_t hint, int32_t gamma2)
+{
+       int32_t a0, a1;
+
+       ml_decompose(a, &a0, &a1, gamma2);
+
+       if (hint == 0)
+       {
+               return a1;
+       }
+       if (gamma2 == (ML_DSA_Q-1)/32)
+       {
+               if (a0 > 0)
+               {
+                       return (a1 + 1) & 15;
+               }
+               else
+               {
+                       return (a1 - 1) & 15;
+               }
+       }
+       else
+    {
+               if (a0 > 0)
+               {
+                       return (a1 == 43) ?  0 : a1 + 1;
+               }
+               else
+               {
+                       return (a1 ==  0) ? 43 : a1 - 1;
+               }
+    }
+}
+
+/*
+ * Described in header
+ */
+int32_t ml_make_hint(int32_t a0, int32_t a1, int32_t gamma2)
+{
+       return (a0 > gamma2 || a0 < -gamma2 || (a0 == -gamma2 && a1 != 0)) ? 1 : 0;
+}
\ No newline at end of file
index 6f5262259671b61475b6e5b7033ee198e9d7f2c9..41b9e8c5317cee34158629ff6302a0f16bd4d56a 100644 (file)
@@ -22,7 +22,8 @@
 #ifndef ML_UTILS_H_
 #define ML_UTILS_H_
 
-#include "ml_params.h"
+#include "ml_kem_params.h"
+#include "ml_dsa_params.h"
 
 /**
  * Returns a mod q for a in [0,2*q) in constant time.
@@ -38,6 +39,21 @@ static inline uint16_t ml_reduce_modq(uint16_t a)
        return (mask & a) | (~mask & diff);
 }
 
+/**
+ * Computes a * 2^-32 mod q montgomery reduction
+ *
+ * Algorithm 49 in FIPS 204.
+ */
+static inline int32_t ml_montgomery_reduce(int64_t a)
+{
+       int32_t r;
+
+       r = (int64_t)(int32_t)a * ML_DSA_QINV;
+       r = (a - (int64_t)r * ML_DSA_Q) >> 32;
+
+       return r;
+}
+
 /**
  * Used to assign the given value based on a condition in constant time and
  * without branching.
@@ -66,4 +82,41 @@ uint32_t ml_read_bytes_le(uint8_t *buf, size_t len);
  */
 void ml_write_bytes_le(uint8_t *buf, size_t len, uint32_t val);
 
+/**
+ * Decompose a into (a1, a0) such that a ≡ a1 * (2 * gamma2) + a0 mod q.
+ *
+ * Algorithm 36 of FIPS 204.
+ *
+ * @param a                    input value to be decomposed
+ * @param a0           low  bits of a
+ * @param a1           high bits of a
+ * @param gamma2       parameter gamma2
+ */
+void ml_decompose(int32_t a, int32_t *a0, int32_t *a1, int32_t gamma2);
+
+/**
+ * Return the high bits a1 of a adjusted according to hint h.
+ *
+ * Algorithm 40 of FIPS 204.
+ *
+ * @param a                    input value to be adjusted
+ * @param hint         hint (0 or 1)
+ * @param gamma2       parameter gamma2
+ * @return                     adjusted high bits of a
+ */
+int32_t ml_use_hint(int32_t a, int32_t hint, int32_t gamma2);
+
+/**
+ * Compute a hint bit indicating whether the low bits a0 of the
+ * input element overflow into the high bits a1.
+ *
+ * Algorithm 39 in FIPS 204.
+ *
+ * @param a0           low bits
+ * @param a1           high bits
+ * @param gamma2       parameter gamma2
+ * @return                     hint bit (0 or 1)
+ */
+int32_t ml_make_hint(int32_t a0, int32_t a1, int32_t gamma2);
+
 #endif /** ML_UTILS_H_ @}*/