]> git.ipfire.org Git - thirdparty/openssl.git/commitdiff
mlx: use TRIE & struct based param decoding
authorPauli <ppzgs1@gmail.com>
Fri, 20 Jun 2025 01:29:00 +0000 (11:29 +1000)
committerPauli <ppzgs1@gmail.com>
Wed, 25 Jun 2025 07:22:07 +0000 (17:22 +1000)
Also fix two bugs with the properties parameter to the set_params call:
- the parameter wasn't listed in the settables table
- the parameter was ignored unless there was a public key present

Reviewed-by: Richard Levitte <levitte@openssl.org>
Reviewed-by: Tomas Mraz <tomas@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/27859)

providers/implementations/keymgmt/mlx_kmgmt.c.in

index 3c00aa2f0dfd24f4fb0b5c2ec27736f5f0863eeb..6d497910ba360ededc89b6dd03de65ed10048776 100644 (file)
@@ -6,6 +6,9 @@
  * in the file LICENSE in the source distribution or at
  * https://www.openssl.org/source/license.html
  */
+{-
+use OpenSSL::paramnames qw(produce_param_decoder);
+-}
 
 #include <openssl/core_dispatch.h>
 #include <openssl/core_names.h>
@@ -144,6 +147,11 @@ static int mlx_kem_match(const void *vkey1, const void *vkey2, int selection)
         && EVP_PKEY_eq(key1->xkey, key2->xkey);
 }
 
+{- produce_param_decoder('ml_kem_import_export',
+                         (['PKEY_PARAM_PRIV_KEY', 'privkey', 'octet_string'],
+                          ['PKEY_PARAM_PUB_KEY',  'pubkey',  'octet_string'],
+                         )); -}
+
 typedef struct export_cb_arg_st {
     const char *algorithm_name;
     uint8_t *pubenc;
@@ -160,7 +168,7 @@ typedef struct export_cb_arg_st {
 static int export_sub_cb(const OSSL_PARAM *params, void *varg)
 {
     EXPORT_CB_ARG *sub_arg = varg;
-    const OSSL_PARAM *p = NULL;
+    struct ml_kem_import_export_st p;
     size_t len;
 
     /*
@@ -170,11 +178,11 @@ static int export_sub_cb(const OSSL_PARAM *params, void *varg)
      */
     if (ossl_param_is_empty(params))
         return 1;
-    if (sub_arg->pubenc != NULL
-        && (p = OSSL_PARAM_locate_const(params, OSSL_PKEY_PARAM_PUB_KEY)) != NULL) {
+    p = ml_kem_import_export_decoder(params);
+    if (sub_arg->pubenc != NULL && p.pubkey != NULL) {
         void *pub = sub_arg->pubenc + sub_arg->puboff;
 
-        if (OSSL_PARAM_get_octet_string(p, &pub, sub_arg->publen, &len) != 1)
+        if (OSSL_PARAM_get_octet_string(p.pubkey, &pub, sub_arg->publen, &len) != 1)
             return 0;
         if (len != sub_arg->publen) {
             ERR_raise_data(ERR_LIB_PROV, ERR_R_INTERNAL_ERROR,
@@ -185,11 +193,10 @@ static int export_sub_cb(const OSSL_PARAM *params, void *varg)
         }
         ++sub_arg->pubcount;
     }
-    if (sub_arg->prvenc != NULL
-        && (p = OSSL_PARAM_locate_const(params, OSSL_PKEY_PARAM_PRIV_KEY)) != NULL) {
+    if (sub_arg->prvenc != NULL && p.privkey != NULL) {
         void *prv = sub_arg->prvenc + sub_arg->prvoff;
 
-        if (OSSL_PARAM_get_octet_string(p, &prv, sub_arg->prvlen, &len) != 1)
+        if (OSSL_PARAM_get_octet_string(p.privkey, &prv, sub_arg->prvlen, &len) != 1)
             return 0;
         if (len != sub_arg->prvlen) {
             ERR_raise_data(ERR_LIB_PROV, ERR_R_INTERNAL_ERROR,
@@ -319,14 +326,8 @@ static int mlx_kem_export(void *vkey, int selection, OSSL_CALLBACK *param_cb,
 
 static const OSSL_PARAM *mlx_kem_imexport_types(int selection)
 {
-    static const OSSL_PARAM key_types[] = {
-        OSSL_PARAM_octet_string(OSSL_PKEY_PARAM_PUB_KEY, NULL, 0),
-        OSSL_PARAM_octet_string(OSSL_PKEY_PARAM_PRIV_KEY, NULL, 0),
-        OSSL_PARAM_END
-    };
-
     if ((selection & OSSL_KEYMGMT_SELECT_KEYPAIR) != 0)
-        return key_types;
+        return ml_kem_import_export_list;
     return NULL;
 }
 
@@ -411,7 +412,7 @@ static int mlx_kem_key_fromdata(MLX_KEY *key,
                                const OSSL_PARAM params[],
                                int include_private)
 {
-    const OSSL_PARAM *param_prv_key = NULL, *param_pub_key;
+    struct ml_kem_import_export_st p;
     const void *pubenc = NULL, *prvenc = NULL;
     size_t pubkey_bytes, prvkey_bytes;
     size_t publen = 0, prvlen = 0;
@@ -422,16 +423,15 @@ static int mlx_kem_key_fromdata(MLX_KEY *key,
     pubkey_bytes = key->minfo->pubkey_bytes + key->xinfo->pubkey_bytes;
     prvkey_bytes = key->minfo->prvkey_bytes + key->xinfo->prvkey_bytes;
 
+    p = ml_kem_import_export_decoder(params);
+
     /* What does the caller want to set? */
-    param_pub_key = OSSL_PARAM_locate_const(params, OSSL_PKEY_PARAM_PUB_KEY);
-    if (param_pub_key != NULL &&
-        OSSL_PARAM_get_octet_string_ptr(param_pub_key, &pubenc, &publen) != 1)
+    if (p.pubkey != NULL &&
+        OSSL_PARAM_get_octet_string_ptr(p.pubkey, &pubenc, &publen) != 1)
         return 0;
-    if (include_private)
-        param_prv_key = OSSL_PARAM_locate_const(params,
-                                                OSSL_PKEY_PARAM_PRIV_KEY);
-    if (param_prv_key != NULL &&
-        OSSL_PARAM_get_octet_string_ptr(param_prv_key, &prvenc, &prvlen) != 1)
+    if (include_private
+            && p.privkey != NULL
+            && OSSL_PARAM_get_octet_string_ptr(p.privkey, &prvenc, &prvlen) != 1)
         return 0;
 
     /* The caller MUST specify at least one of the public or private keys. */
@@ -472,19 +472,18 @@ static int mlx_kem_import(void *vkey, int selection, const OSSL_PARAM params[])
     return mlx_kem_key_fromdata(key, params, include_private);
 }
 
+{- produce_param_decoder('mlx_get_params',
+                         (['PKEY_PARAM_BITS',               'bits',     'int'],
+                          ['PKEY_PARAM_SECURITY_BITS',      'secbits',  'int'],
+                          ['PKEY_PARAM_MAX_SIZE',           'maxsize',  'int'],
+                          ['PKEY_PARAM_SECURITY_CATEGORY',  'seccat',   'int'],
+                          ['PKEY_PARAM_ENCODED_PUBLIC_KEY', 'pub',      'octet_string'],
+                          ['PKEY_PARAM_PRIV_KEY',           'priv',     'octet_string'],
+                         )); -}
+
 static const OSSL_PARAM *mlx_kem_gettable_params(void *provctx)
 {
-    static const OSSL_PARAM arr[] = {
-        OSSL_PARAM_int(OSSL_PKEY_PARAM_BITS, NULL),
-        OSSL_PARAM_int(OSSL_PKEY_PARAM_SECURITY_BITS, NULL),
-        OSSL_PARAM_int(OSSL_PKEY_PARAM_MAX_SIZE, NULL),
-        OSSL_PARAM_int(OSSL_PKEY_PARAM_SECURITY_CATEGORY, NULL),
-        OSSL_PARAM_octet_string(OSSL_PKEY_PARAM_ENCODED_PUBLIC_KEY, NULL, 0),
-        OSSL_PARAM_octet_string(OSSL_PKEY_PARAM_PRIV_KEY, NULL, 0),
-        OSSL_PARAM_END
-    };
-
-    return arr;
+    return mlx_get_params_list;
 }
 
 /*
@@ -493,42 +492,40 @@ static const OSSL_PARAM *mlx_kem_gettable_params(void *provctx)
 static int mlx_kem_get_params(void *vkey, OSSL_PARAM params[])
 {
     MLX_KEY *key = vkey;
-    OSSL_PARAM *p, *pub, *prv = NULL;
+    OSSL_PARAM *pub, *prv = NULL;
     EXPORT_CB_ARG sub_arg;
     int selection;
+    struct mlx_get_params_st p;
     size_t publen = key->minfo->pubkey_bytes + key->xinfo->pubkey_bytes;
     size_t prvlen = key->minfo->prvkey_bytes + key->xinfo->prvkey_bytes;
 
+    p = mlx_get_params_decoder(params);
+
     /* The reported "bit" count is those of the ML-KEM key */
-    p = OSSL_PARAM_locate(params, OSSL_PKEY_PARAM_BITS);
-    if (p != NULL)
-        if (!OSSL_PARAM_set_int(p, key->minfo->bits))
+    if (p.bits != NULL)
+        if (!OSSL_PARAM_set_int(p.bits, key->minfo->bits))
             return 0;
 
     /* The reported security bits are those of the ML-KEM key */
-    p = OSSL_PARAM_locate(params, OSSL_PKEY_PARAM_SECURITY_BITS);
-    if (p != NULL)
-        if (!OSSL_PARAM_set_int(p, key->minfo->secbits))
+    if (p.secbits != NULL)
+        if (!OSSL_PARAM_set_int(p.secbits, key->minfo->secbits))
             return 0;
 
     /* The reported security category are those of the ML-KEM key */
-    p = OSSL_PARAM_locate(params, OSSL_PKEY_PARAM_SECURITY_CATEGORY);
-    if (p != NULL)
-        if (!OSSL_PARAM_set_int(p, key->minfo->security_category))
+    if (p.seccat != NULL)
+        if (!OSSL_PARAM_set_int(p.seccat, key->minfo->security_category))
             return 0;
 
     /* The ciphertext sizes are additive */
-    p = OSSL_PARAM_locate(params, OSSL_PKEY_PARAM_MAX_SIZE);
-    if (p != NULL)
-        if (!OSSL_PARAM_set_int(p, key->minfo->ctext_bytes + key->xinfo->pubkey_bytes))
+    if (p.maxsize != NULL)
+        if (!OSSL_PARAM_set_int(p.maxsize, key->minfo->ctext_bytes + key->xinfo->pubkey_bytes))
             return 0;
 
     if (!mlx_kem_have_pubkey(key))
         return 1;
 
     memset(&sub_arg, 0, sizeof(sub_arg));
-    pub = OSSL_PARAM_locate(params, OSSL_PKEY_PARAM_ENCODED_PUBLIC_KEY);
-    if (pub != NULL) {
+    if ((pub = p.pub) != NULL) {
         if (pub->data_type != OSSL_PARAM_OCTET_STRING)
             return 0;
         pub->return_size = publen;
@@ -545,8 +542,7 @@ static int mlx_kem_get_params(void *vkey, OSSL_PARAM params[])
         }
     }
     if (mlx_kem_have_prvkey(key)) {
-        prv = OSSL_PARAM_locate(params, OSSL_PKEY_PARAM_PRIV_KEY);
-        if (prv != NULL) {
+        if ((prv = p.priv) != NULL) {
             if (prv->data_type != OSSL_PARAM_OCTET_STRING)
                 return 0;
             prv->return_size = prvlen;
@@ -582,29 +578,36 @@ static int mlx_kem_get_params(void *vkey, OSSL_PARAM params[])
     return 1;
 }
 
+{- produce_param_decoder('mlx_set_params',
+                         (['PKEY_PARAM_ENCODED_PUBLIC_KEY', 'pub',   'octet_string'],
+                          ['PKEY_PARAM_PROPERTIES',         'propq', 'utf8_string'],
+                         )); -}
+
 static const OSSL_PARAM *mlx_kem_settable_params(void *provctx)
 {
-    static const OSSL_PARAM arr[] = {
-        OSSL_PARAM_octet_string(OSSL_PKEY_PARAM_ENCODED_PUBLIC_KEY, NULL, 0),
-        OSSL_PARAM_END
-    };
-
-    return arr;
+    return mlx_set_params_list;
 }
 
 static int mlx_kem_set_params(void *vkey, const OSSL_PARAM params[])
 {
     MLX_KEY *key = vkey;
-    const OSSL_PARAM *p;
+    struct mlx_set_params_st p;
     const void *pubenc = NULL;
     size_t publen = 0;
 
     if (ossl_param_is_empty(params))
         return 1;
 
-    /* Only one settable parameter is supported */
-    p = OSSL_PARAM_locate_const(params, OSSL_PKEY_PARAM_ENCODED_PUBLIC_KEY);
-    if (p == NULL)
+    p = mlx_set_params_decoder(params);
+
+    if (p.propq != NULL) {
+        OPENSSL_free(key->propq);
+        key->propq = NULL;
+        if (!OSSL_PARAM_get_utf8_string(p.propq, &key->propq, 0))
+            return 0;
+    }
+
+    if (p.pub == NULL)
         return 1;
 
     /* Key mutation is reportedly generally not allowed */
@@ -615,17 +618,9 @@ static int mlx_kem_set_params(void *vkey, const OSSL_PARAM params[])
         return 0;
     }
     /* An unlikely failure mode is the parameter having some unexpected type */
-    if (!OSSL_PARAM_get_octet_string_ptr(p, &pubenc, &publen))
+    if (!OSSL_PARAM_get_octet_string_ptr(p.pub, &pubenc, &publen))
         return 0;
 
-    p = OSSL_PARAM_locate_const(params, OSSL_PKEY_PARAM_PROPERTIES);
-    if (p != NULL) {
-        OPENSSL_free(key->propq);
-        key->propq = NULL;
-        if (!OSSL_PARAM_get_utf8_string(p, &key->propq, 0))
-            return 0;
-    }
-
     if (publen != key->minfo->pubkey_bytes + key->xinfo->pubkey_bytes) {
         ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_KEY);
         return 0;
@@ -634,22 +629,27 @@ static int mlx_kem_set_params(void *vkey, const OSSL_PARAM params[])
     return load_keys(key, pubenc, publen, NULL, 0);
 }
 
+{- produce_param_decoder('mlx_gen_set_params',
+                         (['PKEY_PARAM_PROPERTIES', 'propq', 'utf8_string'],
+                         )); -}
+
 static int mlx_kem_gen_set_params(void *vgctx, const OSSL_PARAM params[])
 {
     PROV_ML_KEM_GEN_CTX *gctx = vgctx;
-    const OSSL_PARAM *p;
+    struct mlx_gen_set_params_st p;
 
     if (gctx == NULL)
         return 0;
     if (ossl_param_is_empty(params))
         return 1;
 
-    p = OSSL_PARAM_locate_const(params, OSSL_PKEY_PARAM_PROPERTIES);
-    if (p != NULL) {
-        if (p->data_type != OSSL_PARAM_UTF8_STRING)
+    p = mlx_gen_set_params_decoder(params);
+
+    if (p.propq != NULL) {
+        if (p.propq->data_type != OSSL_PARAM_UTF8_STRING)
             return 0;
         OPENSSL_free(gctx->propq);
-        if ((gctx->propq = OPENSSL_strdup(p->data)) == NULL)
+        if ((gctx->propq = OPENSSL_strdup(p.propq->data)) == NULL)
             return 0;
     }
     return 1;
@@ -682,12 +682,7 @@ static void *mlx_kem_gen_init(int evp_type, OSSL_LIB_CTX *libctx,
 static const OSSL_PARAM *mlx_kem_gen_settable_params(ossl_unused void *vgctx,
                                                      ossl_unused void *provctx)
 {
-    static OSSL_PARAM settable[] = {
-        OSSL_PARAM_octet_string(OSSL_PKEY_PARAM_PROPERTIES, NULL, 0),
-        OSSL_PARAM_END
-    };
-
-    return settable;
+    return mlx_gen_set_params_list;
 }
 
 static void *mlx_kem_gen(void *vgctx, OSSL_CALLBACK *osslcb, void *cbarg)