]> git.ipfire.org Git - thirdparty/openssl.git/commitdiff
PROV: Add RSA-PSS specific OSSL_FUNC_KEYMGMT_LOAD function
authorRichard Levitte <levitte@openssl.org>
Fri, 29 Jan 2021 03:47:47 +0000 (04:47 +0100)
committerRichard Levitte <levitte@openssl.org>
Fri, 19 Mar 2021 15:46:39 +0000 (16:46 +0100)
The OSSL_FUNC_KEYMGMT_LOAD function for both plain RSA and RSA-PSS
keys now also check that the key to be loaded is the correct type,
and refuse to load it if it's not.

Reviewed-by: Tomas Mraz <tomas@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/14314)

providers/implementations/keymgmt/rsa_kmgmt.c

index 1c4fb3bcd5405dea812de19b08da4657458fb005..394f3836dd71434de36ba3d42051372839386594 100644 (file)
@@ -36,6 +36,7 @@ static OSSL_FUNC_keymgmt_gen_settable_params_fn rsapss_gen_settable_params;
 static OSSL_FUNC_keymgmt_gen_fn rsa_gen;
 static OSSL_FUNC_keymgmt_gen_cleanup_fn rsa_gen_cleanup;
 static OSSL_FUNC_keymgmt_load_fn rsa_load;
+static OSSL_FUNC_keymgmt_load_fn rsapss_load;
 static OSSL_FUNC_keymgmt_free_fn rsa_freedata;
 static OSSL_FUNC_keymgmt_get_params_fn rsa_get_params;
 static OSSL_FUNC_keymgmt_gettable_params_fn rsa_gettable_params;
@@ -610,13 +611,18 @@ static void rsa_gen_cleanup(void *genctx)
     OPENSSL_free(gctx);
 }
 
-void *rsa_load(const void *reference, size_t reference_sz)
+static void *common_load(const void *reference, size_t reference_sz,
+                         int expected_rsa_type)
 {
     RSA *rsa = NULL;
 
     if (ossl_prov_is_running() && reference_sz == sizeof(rsa)) {
         /* The contents of the reference is the address to our object */
         rsa = *(RSA **)reference;
+
+        if (RSA_test_flags(rsa, RSA_FLAG_TYPE_MASK) != expected_rsa_type)
+            return NULL;
+
         /* We grabbed, so we detach it */
         *(RSA **)reference = NULL;
         return rsa;
@@ -624,6 +630,16 @@ void *rsa_load(const void *reference, size_t reference_sz)
     return NULL;
 }
 
+static void *rsa_load(const void *reference, size_t reference_sz)
+{
+    return common_load(reference, reference_sz, RSA_FLAG_TYPE_RSA);
+}
+
+static void *rsapss_load(const void *reference, size_t reference_sz)
+{
+    return common_load(reference, reference_sz, RSA_FLAG_TYPE_RSASSAPSS);
+}
+
 /* For any RSA key, we use the "RSA" algorithms regardless of sub-type. */
 static const char *rsa_query_operation_name(int operation_id)
 {
@@ -661,7 +677,7 @@ const OSSL_DISPATCH ossl_rsapss_keymgmt_functions[] = {
       (void (*)(void))rsapss_gen_settable_params },
     { OSSL_FUNC_KEYMGMT_GEN, (void (*)(void))rsa_gen },
     { OSSL_FUNC_KEYMGMT_GEN_CLEANUP, (void (*)(void))rsa_gen_cleanup },
-    { OSSL_FUNC_KEYMGMT_LOAD, (void (*)(void))rsa_load },
+    { OSSL_FUNC_KEYMGMT_LOAD, (void (*)(void))rsapss_load },
     { OSSL_FUNC_KEYMGMT_FREE, (void (*)(void))rsa_freedata },
     { OSSL_FUNC_KEYMGMT_GET_PARAMS, (void (*) (void))rsa_get_params },
     { OSSL_FUNC_KEYMGMT_GETTABLE_PARAMS, (void (*) (void))rsa_gettable_params },