]> git.ipfire.org Git - thirdparty/openssl.git/blobdiff - crypto/rsa/rsa_backend.c
RSA: Add a less loaded PSS-parameter structure
[thirdparty/openssl.git] / crypto / rsa / rsa_backend.c
index cf0bff0822e284554f10bdec45a6f0ef5805affc..7497a8579c994a65b4844ce3aa9dbe39eb6e4fb6 100644 (file)
@@ -7,10 +7,16 @@
  * https://www.openssl.org/source/license.html
  */
 
+#include <string.h>
 #include <openssl/core_names.h>
 #include <openssl/params.h>
+#include <openssl/evp.h>
+#include "internal/sizes.h"
+#include "internal/param_build_set.h"
 #include "crypto/rsa.h"
 
+#include "e_os.h"                /* strcasecmp for Windows() */
+
 /*
  * The intention with the "backend" source file is to offer backend support
  * for legacy backends (EVP_PKEY_ASN1_METHOD and EVP_PKEY_METHOD) and provider
@@ -146,3 +152,139 @@ int rsa_todata(RSA *rsa, OSSL_PARAM_BLD *bld, OSSL_PARAM params[])
     sk_BIGNUM_const_free(coeffs);
     return ret;
 }
+
+int rsa_pss_params_30_todata(const RSA_PSS_PARAMS_30 *pss, const char *propq,
+                             OSSL_PARAM_BLD *bld, OSSL_PARAM params[])
+{
+    if (!rsa_pss_params_30_is_unrestricted(pss)) {
+        int hashalg_nid = rsa_pss_params_30_hashalg(pss);
+        int maskgenalg_nid = rsa_pss_params_30_maskgenalg(pss);
+        int maskgenhashalg_nid = rsa_pss_params_30_maskgenhashalg(pss);
+        int saltlen = rsa_pss_params_30_saltlen(pss);
+        int default_hashalg_nid = rsa_pss_params_30_hashalg(NULL);
+        int default_maskgenalg_nid = rsa_pss_params_30_maskgenalg(NULL);
+        int default_maskgenhashalg_nid = rsa_pss_params_30_maskgenhashalg(NULL);
+        const char *mdname =
+            (hashalg_nid == default_hashalg_nid
+             ? NULL : rsa_oaeppss_nid2name(hashalg_nid));
+        const char *mgfname =
+            (maskgenalg_nid == default_maskgenalg_nid
+             ? NULL : rsa_oaeppss_nid2name(maskgenalg_nid));
+        const char *mgf1mdname =
+            (maskgenhashalg_nid == default_maskgenhashalg_nid
+             ? NULL : rsa_oaeppss_nid2name(maskgenhashalg_nid));
+        const char *key_md = OSSL_PKEY_PARAM_RSA_DIGEST;
+        const char *key_mgf = OSSL_PKEY_PARAM_RSA_MASKGENFUNC;
+        const char *key_mgf1_md = OSSL_PKEY_PARAM_RSA_MGF1_DIGEST;
+        const char *key_saltlen = OSSL_PKEY_PARAM_RSA_PSS_SALTLEN;
+
+        /*
+         * To ensure that the key isn't seen as unrestricted by the recipient,
+         * we make sure that at least one PSS-related parameter is passed, even
+         * if it has a default value; saltlen.
+         */
+        if ((mdname != NULL
+             && !ossl_param_build_set_utf8_string(bld, params, key_md, mdname))
+            || (mgfname != NULL
+                && !ossl_param_build_set_utf8_string(bld, params,
+                                                     key_mgf, mgfname))
+            || (mgf1mdname != NULL
+                && !ossl_param_build_set_utf8_string(bld, params,
+                                                     key_mgf1_md, mgf1mdname))
+            || (!ossl_param_build_set_int(bld, params, key_saltlen, saltlen)))
+            return 0;
+    }
+    return 1;
+}
+
+int rsa_pss_params_30_fromdata(RSA_PSS_PARAMS_30 *pss_params,
+                               const OSSL_PARAM params[], OPENSSL_CTX *libctx)
+{
+    const OSSL_PARAM *param_md, *param_mgf, *param_mgf1md,  *param_saltlen;
+    EVP_MD *md = NULL, *mgf1md = NULL;
+    int saltlen;
+    int ret = 0;
+
+    if (pss_params == NULL)
+        return 0;
+
+    param_md =
+        OSSL_PARAM_locate_const(params, OSSL_PKEY_PARAM_RSA_DIGEST);
+    param_mgf =
+        OSSL_PARAM_locate_const(params, OSSL_PKEY_PARAM_RSA_MASKGENFUNC);
+    param_mgf1md =
+        OSSL_PARAM_locate_const(params, OSSL_PKEY_PARAM_RSA_MGF1_DIGEST);
+    param_saltlen =
+        OSSL_PARAM_locate_const(params, OSSL_PKEY_PARAM_RSA_PSS_SALTLEN);
+
+    /*
+     * If we get any of the parameters, we know we have at least some
+     * restrictions, so we start by setting default values, and let each
+     * parameter override their specific restriction data.
+     */
+    if (param_md != NULL || param_mgf != NULL || param_mgf1md != NULL
+        || param_saltlen != NULL)
+        if (!rsa_pss_params_30_set_defaults(pss_params))
+            return 0;
+
+    if (param_mgf != NULL) {
+        int default_maskgenalg_nid = rsa_pss_params_30_maskgenalg(NULL);
+        const char *mgfname = NULL;
+
+        if (param_mgf->data_type == OSSL_PARAM_UTF8_STRING)
+            mgfname = param_mgf->data;
+        else if (!OSSL_PARAM_get_utf8_ptr(param_mgf, &mgfname))
+            return 0;
+
+        /* TODO Revisit this if / when a new MGF algorithm appears */
+        if (strcasecmp(param_mgf->data,
+                       rsa_mgf_nid2name(default_maskgenalg_nid)) != 0)
+            return 0;
+    }
+
+    /*
+     * We're only interested in the NIDs that correspond to the MDs, so the
+     * exact propquery is unimportant in the EVP_MD_fetch() calls below.
+     */
+
+    if (param_md != NULL) {
+        const char *mdname = NULL;
+
+        if (param_md->data_type == OSSL_PARAM_UTF8_STRING)
+            mdname = param_md->data;
+        else if (!OSSL_PARAM_get_utf8_ptr(param_mgf, &mdname))
+            goto err;
+
+        if ((md = EVP_MD_fetch(libctx, mdname, NULL)) == NULL
+            || !rsa_pss_params_30_set_hashalg(pss_params,
+                                              rsa_oaeppss_md2nid(md)))
+            goto err;
+    }
+
+    if (param_mgf1md != NULL) {
+        const char *mgf1mdname = NULL;
+
+        if (param_mgf1md->data_type == OSSL_PARAM_UTF8_STRING)
+            mgf1mdname = param_mgf1md->data;
+        else if (!OSSL_PARAM_get_utf8_ptr(param_mgf, &mgf1mdname))
+            goto err;
+
+        if ((mgf1md = EVP_MD_fetch(libctx, mgf1mdname, NULL)) == NULL
+            || !rsa_pss_params_30_set_maskgenhashalg(pss_params,
+                                                     rsa_oaeppss_md2nid(mgf1md)))
+            goto err;
+    }
+
+    if (param_saltlen != NULL) {
+        if (!OSSL_PARAM_get_int(param_saltlen, &saltlen)
+            || !rsa_pss_params_30_set_saltlen(pss_params, saltlen))
+            goto err;
+    }
+
+    ret = 1;
+
+ err:
+    EVP_MD_free(md);
+    EVP_MD_free(mgf1md);
+    return ret;
+}