static OSSL_FUNC_keymgmt_gen_fn rsa_gen;
static OSSL_FUNC_keymgmt_gen_cleanup_fn rsa_gen_cleanup;
static OSSL_FUNC_keymgmt_load_fn rsa_load;
+static OSSL_FUNC_keymgmt_load_fn rsapss_load;
static OSSL_FUNC_keymgmt_free_fn rsa_freedata;
static OSSL_FUNC_keymgmt_get_params_fn rsa_get_params;
static OSSL_FUNC_keymgmt_gettable_params_fn rsa_gettable_params;
OPENSSL_free(gctx);
}
-void *rsa_load(const void *reference, size_t reference_sz)
+static void *common_load(const void *reference, size_t reference_sz,
+ int expected_rsa_type)
{
RSA *rsa = NULL;
if (ossl_prov_is_running() && reference_sz == sizeof(rsa)) {
/* The contents of the reference is the address to our object */
rsa = *(RSA **)reference;
+
+ if (RSA_test_flags(rsa, RSA_FLAG_TYPE_MASK) != expected_rsa_type)
+ return NULL;
+
/* We grabbed, so we detach it */
*(RSA **)reference = NULL;
return rsa;
return NULL;
}
+static void *rsa_load(const void *reference, size_t reference_sz)
+{
+ return common_load(reference, reference_sz, RSA_FLAG_TYPE_RSA);
+}
+
+static void *rsapss_load(const void *reference, size_t reference_sz)
+{
+ return common_load(reference, reference_sz, RSA_FLAG_TYPE_RSASSAPSS);
+}
+
/* For any RSA key, we use the "RSA" algorithms regardless of sub-type. */
static const char *rsa_query_operation_name(int operation_id)
{
(void (*)(void))rsapss_gen_settable_params },
{ OSSL_FUNC_KEYMGMT_GEN, (void (*)(void))rsa_gen },
{ OSSL_FUNC_KEYMGMT_GEN_CLEANUP, (void (*)(void))rsa_gen_cleanup },
- { OSSL_FUNC_KEYMGMT_LOAD, (void (*)(void))rsa_load },
+ { OSSL_FUNC_KEYMGMT_LOAD, (void (*)(void))rsapss_load },
{ OSSL_FUNC_KEYMGMT_FREE, (void (*)(void))rsa_freedata },
{ OSSL_FUNC_KEYMGMT_GET_PARAMS, (void (*) (void))rsa_get_params },
{ OSSL_FUNC_KEYMGMT_GETTABLE_PARAMS, (void (*) (void))rsa_gettable_params },