]> git.ipfire.org Git - thirdparty/openssl.git/commitdiff
More seed and private key checks for ML-DSA
authorViktor Dukhovni <openssl-users@dukhovni.org>
Fri, 21 Feb 2025 08:47:36 +0000 (19:47 +1100)
committerViktor Dukhovni <openssl-users@dukhovni.org>
Tue, 25 Feb 2025 01:49:49 +0000 (12:49 +1100)
- Check seed/key consistency when generating from a seed and the private
  key is also given.
- Improve error reporting when the private key does not match an
  explicit public key.

Reviewed-by: Tim Hudson <tjh@openssl.org>
Reviewed-by: Shane Lontis <shane.lontis@oracle.com>
(Merged from https://github.com/openssl/openssl/pull/26865)

crypto/ml_dsa/ml_dsa_encoders.c
crypto/ml_dsa/ml_dsa_key.c
include/crypto/ml_dsa.h
providers/implementations/keymgmt/ml_dsa_kmgmt.c
test/recipes/15-test_ml_dsa_codecs.t
test/recipes/15-test_ml_kem_codecs.t

index b404ddc6357d24d24cd432337393ebef5fd2212a..078d25122c2cdb5db5ee582b00dad08d3208ba15 100644 (file)
@@ -8,7 +8,9 @@
  */
 
 #include <openssl/byteorder.h>
+#include <openssl/err.h>
 #include <openssl/evp.h>
+#include <openssl/proverr.h>
 #include "ml_dsa_hash.h"
 #include "ml_dsa_key.h"
 #include "ml_dsa_sign.h"
@@ -811,8 +813,13 @@ int ossl_ml_dsa_sk_decode(ML_DSA_KEY *key, const uint8_t *in, size_t in_len)
      * the |tr| value in the private key, else the key was corrupted.
      */
     if (!ossl_ml_dsa_key_public_from_private(key)
-            || memcmp(input_tr, key->tr, sizeof(input_tr)) != 0)
+            || memcmp(input_tr, key->tr, sizeof(input_tr)) != 0) {
+        ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
+                       "%s private key does not match its pubkey part",
+                       key->params->alg);
+        ossl_ml_dsa_key_reset(key);
         goto err;
+    }
 
     return 1;
  err:
index be041fc4e300ea43d9d41eb4b516405c1b998660..48dbd5733e6980624d6b366f293a6c0d2b2d2e6a 100644 (file)
@@ -9,6 +9,7 @@
 
 #include <openssl/core_dispatch.h>
 #include <openssl/core_names.h>
+#include <openssl/err.h>
 #include <openssl/params.h>
 #include <openssl/proverr.h>
 #include <openssl/rand.h>
@@ -142,10 +143,19 @@ void ossl_ml_dsa_key_free(ML_DSA_KEY *key)
  */
 void ossl_ml_dsa_key_reset(ML_DSA_KEY *key)
 {
-    vector_zero(&key->s2);
-    vector_zero(&key->s1);
-    vector_zero(&key->t0);
-    vector_free(&key->s1);
+    /*
+     * The allocation for |s1.poly| subsumes those for |s2| and |t0|, which we
+     * must not access after |s1|'s poly is freed.
+     */
+    if (key->s1.poly != NULL) {
+        vector_zero(&key->s1);
+        vector_zero(&key->s2);
+        vector_zero(&key->t0);
+        vector_free(&key->s1);
+        key->s2.poly = NULL;
+        key->t0.poly = NULL;
+    }
+    /* The |t1| vector is public and allocated separately */
     vector_free(&key->t1);
     OPENSSL_cleanse(key->K, sizeof(key->K));
     OPENSSL_free(key->pub_encoding);
@@ -447,6 +457,8 @@ err:
 int ossl_ml_dsa_generate_key(ML_DSA_KEY *out)
 {
     size_t seed_len = ML_DSA_SEED_BYTES;
+    uint8_t *sk;
+    int ret;
 
     if (out->seed == NULL) {
         if ((out->seed = OPENSSL_malloc(seed_len)) == NULL)
@@ -458,9 +470,22 @@ int ossl_ml_dsa_generate_key(ML_DSA_KEY *out)
         }
     }
     /* We're generating from a seed, drop private prekey encoding */
-    OPENSSL_free(out->priv_encoding);
+    sk = out->priv_encoding;
     out->priv_encoding = NULL;
-    return keygen_internal(out);
+    if (sk == NULL) {
+        ret = keygen_internal(out);
+    } else {
+        if ((ret = keygen_internal(out)) != 0
+            && memcmp(out->priv_encoding, sk, out->params->sk_len) != 0) {
+            ret = 0;
+            ossl_ml_dsa_key_reset(out);
+            ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
+                           "explicit %s private key does not match seed",
+                           out->params->alg);
+        }
+        OPENSSL_free(sk);
+    }
+    return ret;
 }
 
 /**
index fd21a40889339a039ee59e23a7c25ed145dd7a40..082a992b13ff70304fd28e4346cf4311d108a46c 100644 (file)
@@ -106,10 +106,6 @@ __owur int ossl_ml_dsa_key_public_from_private(ML_DSA_KEY *key);
 __owur int ossl_ml_dsa_pk_decode(ML_DSA_KEY *key, const uint8_t *in, size_t in_len);
 __owur int ossl_ml_dsa_sk_decode(ML_DSA_KEY *key, const uint8_t *in, size_t in_len);
 
-__owur int ossl_ml_dsa_key_public_from_private(ML_DSA_KEY *key);
-__owur int ossl_ml_dsa_pk_decode(ML_DSA_KEY *key, const uint8_t *in, size_t in_len);
-__owur int ossl_ml_dsa_sk_decode(ML_DSA_KEY *key, const uint8_t *in, size_t in_len);
-
 __owur int ossl_ml_dsa_sign(const ML_DSA_KEY *priv, int msg_is_mu,
                             const uint8_t *msg, size_t msg_len,
                             const uint8_t *context, size_t context_len,
index 24e8ceb2f5a3a3f99a3e9fb2f62cb2737d1e8689..ea549bc95d7406162c5df8803c471473546b5d6d 100644 (file)
@@ -236,7 +236,7 @@ static int ml_dsa_key_fromdata(ML_DSA_KEY *key, const OSSL_PARAM params[],
     if (seed_len != 0
         && (sk_len == 0
             || (ossl_ml_dsa_key_get_prov_flags(key) & ML_DSA_KEY_PREFER_SEED))) {
-        if (!ossl_ml_dsa_set_prekey(key, 0, 0, seed, seed_len, NULL, 0))
+        if (!ossl_ml_dsa_set_prekey(key, 0, 0, seed, seed_len, sk, sk_len))
             return 0;
         if (!ossl_ml_dsa_generate_key(key)) {
             ERR_raise(ERR_LIB_PROV, PROV_R_FAILED_TO_GENERATE_KEY);
@@ -251,9 +251,15 @@ static int ml_dsa_key_fromdata(ML_DSA_KEY *key, const OSSL_PARAM params[],
     }
 
     /* Error if the supplied public key does not match the generated key */
-    return pk_len == 0
+    if (pk_len == 0
         || seed_len + sk_len == 0
-        || memcmp(ossl_ml_dsa_key_get_pub(key), pk, pk_len) == 0;
+        || memcmp(ossl_ml_dsa_key_get_pub(key), pk, pk_len) == 0)
+        return 1;
+    ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
+                   "explicit %s public key does not match private",
+                   key_params->alg);
+    ossl_ml_dsa_key_reset(key);
+    return 0;
 }
 
 static int ml_dsa_import(void *keydata, int selection, const OSSL_PARAM params[])
index 274c830d7181b59e3774f21647a23c6c3ba25661..9f4d5e595fa3597f43d303a3c8aea2f84a3dfc1b 100644 (file)
@@ -25,13 +25,18 @@ my @formats = qw(seed-priv priv-only seed-only oqskeypair bare-seed bare-priv);
 plan skip_all => "ML-DSA isn't supported in this build"
     if disabled("ml-dsa");
 
-plan tests => @algs * (16 + 10 * @formats);
+plan tests => @algs * (23 + 10 * @formats);
 my $seed = join ("", map {sprintf "%02x", $_} (0..31));
+my $weed = join ("", map {sprintf "%02x", $_} (1..32));
 my $ikme = join ("", map {sprintf "%02x", $_} (0..31));
+my %alg = ("44" => [4, 4, 2560], "65" => [6, 5, 4032], "87" => [8, 7, 4896]);
 
 foreach my $alg (@algs) {
     my $pub = sprintf("pub-%s.pem", $alg);
     my %formats = map { ($_, sprintf("prv-%s-%s.pem", $alg, $_)) } @formats;
+    my ($k, $l, $sk_len) = @{$alg{$alg}};
+    # The number of low-bits |d| in t_0 is 13 across all the variants
+    my $t0_len = $k * 13 * 32;
 
     # (1 + 6 * @formats) tests
     my $i = 0;
@@ -40,11 +45,11 @@ foreach my $alg (@algs) {
     ok(run(app(['openssl', 'pkey', '-pubin', '-in', $in0,
                 '-outform', 'DER', '-out', $der0])));
     foreach my $f (keys %formats) {
-        my $k = $formats{$f};
+        my $kf = $formats{$f};
         my %pruned = %formats;
         delete $pruned{$f};
         my $rest = join(", ", keys %pruned);
-        my $in = data_file($k);
+        my $in = data_file($kf);
         my $der = sprintf("pub-%s.%d.der", $alg, $i);
         #
         # Compare expected DER public key with DER public key of private
@@ -75,8 +80,8 @@ foreach my $alg (@algs) {
     ok(run(app([qw(openssl pkeyutl -verify -rawin -pubin -inkey),
                 $in0, '-in', $der0, '-sigfile', $refsig],
                sprintf("Signature verify with pubkey: %s", $alg))));
-    while (my ($f, $k) = each %formats) {
-        my $sk = data_file($k);
+    while (my ($f, $kf) = each %formats) {
+        my $sk = data_file($kf);
         my $s = sprintf("sig-%s.%d.dat", $alg, $i++);
         ok(run(app([qw(openssl pkeyutl -sign -rawin -inkey), $sk, '-in', $der0,
                     qw(-pkeyopt deterministic:1 -out), $s])));
@@ -143,13 +148,67 @@ foreach my $alg (@algs) {
 
     # (2 * @formats) tests
     # Check text encoding
-    while (my ($f, $k) = each %formats) {
+    while (my ($f, $kf) = each %formats) {
         my $txt =  sprintf("prv-%s-%s.txt", $alg,
                             ($f =~ m{seed}) ? 'seed' : 'priv');
         my $out = sprintf("prv-%s-%s.txt", $alg, $f);
-        ok(run(app(['openssl', 'pkey', '-in', data_file($k),
+        ok(run(app(['openssl', 'pkey', '-in', data_file($kf),
                     '-noout', '-text', '-out', $out])));
         ok(!compare(data_file($txt), $out),
             sprintf("text form private key: %s with %s", $alg, $f));
     }
+
+    # (8 tests): Test import/load seed/priv consistency checks
+    my $real = sprintf('real-%s.der', $alg);
+    my $fake = sprintf('fake-%s.der', $alg);
+    my $mixt = sprintf('mixt-%s.der', $alg);
+    my $mash = sprintf('mash-%s.der', $alg);
+    ok(run(app([qw(openssl genpkey -algorithm), "ml-dsa-$alg",
+                qw(-provparam ml-dsa.output_formats=seed-priv -pkeyopt),
+                "hexseed:$seed", qw(-outform DER -out), $real])),
+        sprintf("create real private key: %s", $alg));
+    ok(run(app([qw(openssl genpkey -algorithm), "ml-dsa-$alg",
+                qw(-provparam ml-dsa.output_formats=seed-priv -pkeyopt),
+                "hexseed:$weed", qw(-outform DER -out), $fake])),
+        sprintf("create fake private key: %s", $alg));
+    my $realfh = IO::File->new($real, "<:raw");
+    my $fakefh = IO::File->new($fake, "<:raw");
+    local $/ = undef;
+    my $realder = <$realfh>;
+    $realfh->close();
+    my $fakeder = <$fakefh>;
+    $fakefh->close();
+    #
+    # - 20 bytes PKCS8 fixed overhead,
+    # - 4 byte private key octet string tag + length
+    # - 4 byte seed + key sequence tag + length
+    #   - 2 byte seed tag + length
+    #     - 32 byte seed
+    #   - 4 byte key tag + length
+    #     - $sk_len private key, ending in t0.
+    #
+    my $p8_len = 28 + (2 + 32) + (4 + $sk_len);
+    ok((length($realder) == $p8_len && length($fakeder) == $p8_len),
+        sprintf("Got expected DER lengths of %s seed-priv key", $alg));
+    my $mixtder = substr($realder, 0, 28 + 34)
+        . substr($fakeder, 28 + 34);
+    my $mixtfh = IO::File->new($mixt, ">:raw");
+    print $mixtfh $mixtder;
+    $mixtfh->close();
+    ok(run(app([qw(openssl pkey -inform DER -noout -in), $real])),
+        sprintf("accept valid keypair: %s", $alg));
+    ok(!run(app([qw(openssl pkey -inform DER -noout -in), $mixt])),
+        sprintf("Using seed reject mismatched private %s", $alg));
+    ok(run(app([qw(openssl pkey -provparam ml-dsa.prefer_seed=no),
+                qw(-inform DER -noout -in), $mixt])),
+        sprintf("Ignoring seed accept mismatched private %s", $alg));
+    # Mutate the t0 vector
+    my $mashder = $realder;
+    substr($mashder, -$t0_len, 1) =~ s{(.)}{chr(ord($1)^1)}es;
+    my $mashfh = IO::File->new($mash, ">:raw");
+    print $mashfh $mashder;
+    $mashfh->close();
+    ok(!run(app([qw(openssl pkey -provparam ml-dsa.prefer_seed=no),
+                 qw(-inform DER -noout -in), $mash])),
+        sprintf("reject real private and mutated public: %s", $alg));
 }
index 88e3d4deab08b64ca521029d04cc0bc0104cb668..feee099aaab87c5d32047b5974e70acf14455068 100644 (file)
@@ -176,11 +176,13 @@ foreach my $alg (@algs) {
                 qw(-provparam ml-kem.output_formats=seed-priv -pkeyopt),
                 "hexseed:$weed", qw(-outform DER -out), $fake])),
         sprintf("create fake private key: %s", $alg));
-    my $realfh = IO::File->new($real, "r");
-    my $fakefh = IO::File->new($fake, "r");
+    my $realfh = IO::File->new($real, "<:raw");
+    my $fakefh = IO::File->new($fake, "<:raw");
     local $/ = undef;
     my $realder = <$realfh>;
     my $fakeder = <$fakefh>;
+    $realfh->close();
+    $fakefh->close();
     #
     # - 20 bytes PKCS8 fixed overhead,
     # - 4 byte private key octet string tag + length
@@ -192,8 +194,9 @@ foreach my $alg (@algs) {
     #     - |ek| public key ('t' vector || 'rho')
     #     - implicit rejection 'z' seed component
     #
-    ok(length($realder) == 28 + (2 + 64) + (4 + $slen + $plen + $zlen)
-        && length($fakeder) == 28 + (2 + 64) + (4 + $slen + $plen + $zlen));
+    my $p8_len = 28 + (2 + 64) + (4 + $slen + $plen + $zlen);
+    ok((length($realder) == $p8_len && length($fakeder) == $p8_len),
+        sprintf("Got expected DER lengths of %s seed-priv key", $alg));
     my $mixtder = substr($realder, 0, 28 + 66 + 4 + $slen)
         . substr($fakeder, 28 + 66 + 4 + $slen, $plen)
         . substr($realder, 28 + 66 + 4 + $slen + $plen, $zlen);