]> git.ipfire.org Git - thirdparty/strongswan.git/blobdiff - src/libstrongswan/plugins/pkcs11/pkcs11_private_key.c
pkcs11: Optionally hash data for PKCS#1 v1.5 RSA signatures in software
[thirdparty/strongswan.git] / src / libstrongswan / plugins / pkcs11 / pkcs11_private_key.c
index 3154460e187874a531c3b11c56c1dc7e7f9017f3..6b8be62658c2bf522122597f1bfdbcf5ba32f0f5 100644 (file)
@@ -1,4 +1,7 @@
 /*
+ * Copyright (C) 2011-2016 Tobias Brunner
+ * HSR Hochschule fuer Technik Rapperswil
+ *
  * Copyright (C) 2010 Martin Willi
  * Copyright (C) 2010 revosec AG
  *
  * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  * for more details.
  */
+/*
+ * Copyright (C) 2016 EDF S.A.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
 
 #include "pkcs11_private_key.h"
 
 #include "pkcs11_library.h"
 #include "pkcs11_manager.h"
+#include "pkcs11_public_key.h"
 
-#include <debug.h>
-#include <threading/mutex.h>
+#include <utils/debug.h>
+#include <asn1/asn1.h>
 
 typedef struct private_pkcs11_private_key_t private_pkcs11_private_key_t;
 
@@ -39,14 +64,14 @@ struct private_pkcs11_private_key_t {
        pkcs11_library_t *lib;
 
        /**
-        * Token session
+        * Slot the token is in
         */
-       CK_SESSION_HANDLE session;
+       CK_SLOT_ID slot;
 
        /**
-        * Mutex to lock session
+        * Token session
         */
-       mutex_t *mutex;
+       CK_SESSION_HANDLE session;
 
        /**
         * Key object on the token
@@ -72,12 +97,18 @@ struct private_pkcs11_private_key_t {
         * References to this key
         */
        refcount_t ref;
+
+       /**
+        * Type of this private key
+        */
+       key_type_t type;
 };
 
+
 METHOD(private_key_t, get_type, key_type_t,
        private_pkcs11_private_key_t *this)
 {
-       return this->pubkey->get_type(this->pubkey);
+       return this->type;
 }
 
 METHOD(private_key_t, get_keysize, int,
@@ -87,20 +118,79 @@ METHOD(private_key_t, get_keysize, int,
 }
 
 /**
- * See header.
+ * Check if a token supports the given mechanism.
+ */
+static bool is_mechanism_supported(pkcs11_library_t *p11, CK_SLOT_ID slot,
+                                                                  const CK_MECHANISM_PTR mech)
+{
+       enumerator_t *mechs;
+       CK_MECHANISM_TYPE type;
+
+       mechs = p11->create_mechanism_enumerator(p11, slot);
+       while (mechs->enumerate(mechs, &type, NULL))
+       {
+               if (type == mech->mechanism)
+               {
+                       mechs->destroy(mechs);
+                       return TRUE;
+               }
+       }
+       mechs->destroy(mechs);
+       return FALSE;
+}
+
+/*
+ * Described in header
  */
-CK_MECHANISM_PTR pkcs11_signature_scheme_to_mech(signature_scheme_t scheme)
+CK_MECHANISM_PTR pkcs11_signature_scheme_to_mech(pkcs11_library_t *p11,
+                                                                                                CK_SLOT_ID slot,
+                                                                                                signature_scheme_t scheme,
+                                                                                                key_type_t type, size_t keylen,
+                                                                                                hash_algorithm_t *hash)
 {
        static struct {
                signature_scheme_t scheme;
                CK_MECHANISM mechanism;
+               key_type_t type;
+               size_t keylen;
+               hash_algorithm_t hash;
        } mappings[] = {
-               {SIGN_RSA_EMSA_PKCS1_NULL,              {CKM_RSA_PKCS,                          NULL, 0}},
-               {SIGN_RSA_EMSA_PKCS1_SHA1,              {CKM_SHA1_RSA_PKCS,                     NULL, 0}},
-               {SIGN_RSA_EMSA_PKCS1_SHA256,    {CKM_SHA256_RSA_PKCS,           NULL, 0}},
-               {SIGN_RSA_EMSA_PKCS1_SHA384,    {CKM_SHA384_RSA_PKCS,           NULL, 0}},
-               {SIGN_RSA_EMSA_PKCS1_SHA512,    {CKM_SHA512_RSA_PKCS,           NULL, 0}},
-               {SIGN_RSA_EMSA_PKCS1_MD5,               {CKM_MD5_RSA_PKCS,                      NULL, 0}},
+               {SIGN_RSA_EMSA_PKCS1_NULL,              {CKM_RSA_PKCS,                  NULL, 0},
+                KEY_RSA, 0,                                                                       HASH_UNKNOWN},
+               {SIGN_RSA_EMSA_PKCS1_SHA2_256,  {CKM_SHA256_RSA_PKCS,   NULL, 0},
+                KEY_RSA, 0,                                                                       HASH_UNKNOWN},
+               {SIGN_RSA_EMSA_PKCS1_SHA2_256,  {CKM_RSA_PKCS,                  NULL, 0},
+                KEY_RSA, 0,                                                                            HASH_SHA256},
+               {SIGN_RSA_EMSA_PKCS1_SHA2_384,  {CKM_SHA384_RSA_PKCS,   NULL, 0},
+                KEY_RSA, 0,                                                                       HASH_UNKNOWN},
+               {SIGN_RSA_EMSA_PKCS1_SHA2_384,  {CKM_RSA_PKCS,                  NULL, 0},
+                KEY_RSA, 0,                                                                            HASH_SHA384},
+               {SIGN_RSA_EMSA_PKCS1_SHA2_512,  {CKM_SHA512_RSA_PKCS,   NULL, 0},
+                KEY_RSA, 0,                                                                       HASH_UNKNOWN},
+               {SIGN_RSA_EMSA_PKCS1_SHA2_512,  {CKM_RSA_PKCS,                  NULL, 0},
+                KEY_RSA, 0,                                                                            HASH_SHA512},
+               {SIGN_RSA_EMSA_PKCS1_SHA1,              {CKM_SHA1_RSA_PKCS,             NULL, 0},
+                KEY_RSA, 0,                                                                       HASH_UNKNOWN},
+               {SIGN_RSA_EMSA_PKCS1_SHA1,              {CKM_RSA_PKCS,                  NULL, 0},
+                KEY_RSA, 0,                                                                              HASH_SHA1},
+               {SIGN_RSA_EMSA_PKCS1_MD5,               {CKM_MD5_RSA_PKCS,              NULL, 0},
+                KEY_RSA, 0,                                                                       HASH_UNKNOWN},
+               {SIGN_ECDSA_WITH_NULL,                  {CKM_ECDSA,                             NULL, 0},
+                KEY_ECDSA, 0,                                                                     HASH_UNKNOWN},
+               {SIGN_ECDSA_WITH_SHA1_DER,              {CKM_ECDSA_SHA1,                NULL, 0},
+                KEY_ECDSA, 0,                                                                     HASH_UNKNOWN},
+               {SIGN_ECDSA_WITH_SHA256_DER,    {CKM_ECDSA,                             NULL, 0},
+                KEY_ECDSA, 0,                                                                          HASH_SHA256},
+               {SIGN_ECDSA_WITH_SHA384_DER,    {CKM_ECDSA,                             NULL, 0},
+                KEY_ECDSA, 0,                                                                          HASH_SHA384},
+               {SIGN_ECDSA_WITH_SHA512_DER,    {CKM_ECDSA,                             NULL, 0},
+                KEY_ECDSA, 0,                                                                          HASH_SHA512},
+               {SIGN_ECDSA_256,                                {CKM_ECDSA,                             NULL, 0},
+                KEY_ECDSA, 256,                                                                        HASH_SHA256},
+               {SIGN_ECDSA_384,                                {CKM_ECDSA,                             NULL, 0},
+                KEY_ECDSA, 384,                                                                        HASH_SHA384},
+               {SIGN_ECDSA_521,                                {CKM_ECDSA,                             NULL, 0},
+                KEY_ECDSA, 521,                                                                        HASH_SHA512},
        };
        int i;
 
@@ -108,6 +198,17 @@ CK_MECHANISM_PTR pkcs11_signature_scheme_to_mech(signature_scheme_t scheme)
        {
                if (mappings[i].scheme == scheme)
                {
+                       size_t len = mappings[i].keylen;
+
+                       if (mappings[i].type != type || (len && keylen != len) ||
+                               !is_mechanism_supported(p11, slot, &mappings[i].mechanism))
+                       {
+                               continue;
+                       }
+                       if (hash)
+                       {
+                               *hash = mappings[i].hash;
+                       }
                        return &mappings[i].mechanism;
                }
        }
@@ -141,7 +242,8 @@ CK_MECHANISM_PTR pkcs11_encryption_scheme_to_mech(encryption_scheme_t scheme)
 /**
  * Reauthenticate to do a signature
  */
-static bool reauth(private_pkcs11_private_key_t *this)
+static bool reauth(private_pkcs11_private_key_t *this,
+                                  CK_SESSION_HANDLE session)
 {
        enumerator_t *enumerator;
        shared_key_t *shared;
@@ -155,7 +257,7 @@ static bool reauth(private_pkcs11_private_key_t *this)
        {
                found = TRUE;
                pin = shared->get_key(shared);
-               rv = this->lib->f->C_Login(this->session, CKU_CONTEXT_SPECIFIC,
+               rv = this->lib->f->C_Login(session, CKU_CONTEXT_SPECIFIC,
                                                                   pin.ptr, pin.len);
                if (rv == CKR_OK)
                {
@@ -175,44 +277,112 @@ static bool reauth(private_pkcs11_private_key_t *this)
 }
 
 METHOD(private_key_t, sign, bool,
-       private_pkcs11_private_key_t *this, signature_scheme_t scheme,
+       private_pkcs11_private_key_t *this, signature_scheme_t scheme, void *params,
        chunk_t data, chunk_t *signature)
 {
        CK_MECHANISM_PTR mechanism;
+       CK_SESSION_HANDLE session;
        CK_BYTE_PTR buf;
        CK_ULONG len;
        CK_RV rv;
+       hash_algorithm_t hash_alg;
+       chunk_t hash = chunk_empty;
 
-       mechanism = pkcs11_signature_scheme_to_mech(scheme);
+       mechanism = pkcs11_signature_scheme_to_mech(this->lib, this->slot, scheme,
+                                                                                               this->type, get_keysize(this),
+                                                                                               &hash_alg);
        if (!mechanism)
        {
                DBG1(DBG_LIB, "signature scheme %N not supported",
                         signature_scheme_names, scheme);
                return FALSE;
        }
-       this->mutex->lock(this->mutex);
-       rv = this->lib->f->C_SignInit(this->session, mechanism, this->object);
-       if (this->reauth && !reauth(this))
+       rv = this->lib->f->C_OpenSession(this->slot, CKF_SERIAL_SESSION, NULL, NULL,
+                                                                        &session);
+       if (rv != CKR_OK)
+       {
+               DBG1(DBG_CFG, "opening PKCS#11 session failed: %N", ck_rv_names, rv);
+               return FALSE;
+       }
+       rv = this->lib->f->C_SignInit(session, mechanism, this->object);
+       if (this->reauth && !reauth(this, session))
        {
+               this->lib->f->C_CloseSession(session);
                return FALSE;
        }
        if (rv != CKR_OK)
        {
-               this->mutex->unlock(this->mutex);
+               this->lib->f->C_CloseSession(session);
                DBG1(DBG_LIB, "C_SignInit() failed: %N", ck_rv_names, rv);
                return FALSE;
        }
+       if (hash_alg != HASH_UNKNOWN)
+       {
+               hasher_t *hasher;
+
+               hasher = lib->crypto->create_hasher(lib->crypto, hash_alg);
+               if (!hasher || !hasher->allocate_hash(hasher, data, &hash))
+               {
+                       DESTROY_IF(hasher);
+                       this->lib->f->C_CloseSession(session);
+                       return FALSE;
+               }
+               hasher->destroy(hasher);
+               switch (scheme)
+               {
+                       case SIGN_RSA_EMSA_PKCS1_SHA1:
+                       case SIGN_RSA_EMSA_PKCS1_SHA2_256:
+                       case SIGN_RSA_EMSA_PKCS1_SHA2_384:
+                       case SIGN_RSA_EMSA_PKCS1_SHA2_512:
+                               /* encode PKCS#1 digestInfo if the token does not support it */
+                               hash = asn1_wrap(ASN1_SEQUENCE, "mm",
+                                                                asn1_algorithmIdentifier(
+                                                                       hasher_algorithm_to_oid(hash_alg)),
+                                                                asn1_wrap(ASN1_OCTET_STRING, "m", hash));
+                               break;
+                       default:
+                               break;
+               }
+               data = hash;
+       }
        len = (get_keysize(this) + 7) / 8;
+       if (this->type == KEY_ECDSA)
+       {       /* signature is twice the length of the base point order */
+               len *= 2;
+       }
        buf = malloc(len);
-       rv = this->lib->f->C_Sign(this->session, data.ptr, data.len, buf, &len);
-       this->mutex->unlock(this->mutex);
+       rv = this->lib->f->C_Sign(session, data.ptr, data.len, buf, &len);
+       this->lib->f->C_CloseSession(session);
+       chunk_free(&hash);
        if (rv != CKR_OK)
        {
                DBG1(DBG_LIB, "C_Sign() failed: %N", ck_rv_names, rv);
                free(buf);
                return FALSE;
        }
-       *signature = chunk_create(buf, len);
+       switch (scheme)
+       {
+               case SIGN_ECDSA_WITH_SHA1_DER:
+               case SIGN_ECDSA_WITH_SHA256_DER:
+               case SIGN_ECDSA_WITH_SHA384_DER:
+               case SIGN_ECDSA_WITH_SHA512_DER:
+               {
+                       chunk_t r, s;
+
+                       /* return an ASN.1 encoded sequence of integers r and s, removing
+                        * any zero-padding */
+                       len /= 2;
+                       r = chunk_skip_zero(chunk_create(buf, len));
+                       s = chunk_skip_zero(chunk_create(buf+len, len));
+                       *signature = asn1_wrap(ASN1_SEQUENCE, "mm",
+                                                                  asn1_integer("c", r), asn1_integer("c", s));
+                       free(buf);
+                       break;
+               }
+               default:
+                       *signature = chunk_create(buf, len);
+                       break;
+       }
        return TRUE;
 }
 
@@ -221,6 +391,7 @@ METHOD(private_key_t, decrypt, bool,
        chunk_t crypt, chunk_t *plain)
 {
        CK_MECHANISM_PTR mechanism;
+       CK_SESSION_HANDLE session;
        CK_BYTE_PTR buf;
        CK_ULONG len;
        CK_RV rv;
@@ -232,22 +403,29 @@ METHOD(private_key_t, decrypt, bool,
                         encryption_scheme_names, scheme);
                return FALSE;
        }
-       this->mutex->lock(this->mutex);
-       rv = this->lib->f->C_DecryptInit(this->session, mechanism, this->object);
-       if (this->reauth && !reauth(this))
+       rv = this->lib->f->C_OpenSession(this->slot, CKF_SERIAL_SESSION, NULL, NULL,
+                                                                        &session);
+       if (rv != CKR_OK)
+       {
+               DBG1(DBG_CFG, "opening PKCS#11 session failed: %N", ck_rv_names, rv);
+               return FALSE;
+       }
+       rv = this->lib->f->C_DecryptInit(session, mechanism, this->object);
+       if (this->reauth && !reauth(this, session))
        {
+               this->lib->f->C_CloseSession(session);
                return FALSE;
        }
        if (rv != CKR_OK)
        {
-               this->mutex->unlock(this->mutex);
+               this->lib->f->C_CloseSession(session);
                DBG1(DBG_LIB, "C_DecryptInit() failed: %N", ck_rv_names, rv);
                return FALSE;
        }
        len = (get_keysize(this) + 7) / 8;
        buf = malloc(len);
-       rv = this->lib->f->C_Decrypt(this->session, crypt.ptr, crypt.len, buf, &len);
-       this->mutex->unlock(this->mutex);
+       rv = this->lib->f->C_Decrypt(session, crypt.ptr, crypt.len, buf, &len);
+       this->lib->f->C_CloseSession(session);
        if (rv != CKR_OK)
        {
                DBG1(DBG_LIB, "C_Decrypt() failed: %N", ck_rv_names, rv);
@@ -294,7 +472,6 @@ METHOD(private_key_t, destroy, void,
                {
                        this->pubkey->destroy(this->pubkey);
                }
-               this->mutex->destroy(this->mutex);
                this->keyid->destroy(this->keyid);
                this->lib->f->C_CloseSession(this->session);
                free(this);
@@ -332,7 +509,8 @@ static pkcs11_library_t* find_lib(char *module)
 /**
  * Find the PKCS#11 lib having a keyid, and optionally a slot
  */
-static pkcs11_library_t* find_lib_by_keyid(chunk_t keyid, int *slot)
+static pkcs11_library_t* find_lib_by_keyid(chunk_t keyid, int *slot,
+                                                                                  CK_OBJECT_CLASS class)
 {
        pkcs11_manager_t *manager;
        enumerator_t *enumerator;
@@ -349,8 +527,7 @@ static pkcs11_library_t* find_lib_by_keyid(chunk_t keyid, int *slot)
        {
                if (*slot == -1 || *slot == current)
                {
-                       /* we look for a public key, it is usually readable without login */
-                       CK_OBJECT_CLASS class = CKO_PUBLIC_KEY;
+                       /* look for a pubkey/cert, it is usually readable without login */
                        CK_ATTRIBUTE tmpl[] = {
                                {CKA_CLASS, &class, sizeof(class)},
                                {CKA_ID, keyid.ptr, keyid.len},
@@ -389,6 +566,120 @@ static pkcs11_library_t* find_lib_by_keyid(chunk_t keyid, int *slot)
        return found;
 }
 
+/**
+ * Find the PKCS#11 lib and CKA_ID of the certificate object of a given
+ * subjectKeyIdentifier and optional slot
+ */
+static pkcs11_library_t* find_lib_and_keyid_by_skid(chunk_t keyid_chunk,
+                                                                                                       chunk_t *ckaid, int *slot)
+{
+       CK_OBJECT_CLASS class = CKO_CERTIFICATE;
+       CK_CERTIFICATE_TYPE type = CKC_X_509;
+       CK_ATTRIBUTE tmpl[] = {
+               {CKA_CLASS, &class, sizeof(class)},
+               {CKA_CERTIFICATE_TYPE, &type, sizeof(type)},
+       };
+       CK_ATTRIBUTE attr[] = {
+               {CKA_VALUE, NULL, 0},
+               {CKA_ID, NULL, 0},
+       };
+       CK_OBJECT_HANDLE object;
+       CK_SESSION_HANDLE session;
+       CK_RV rv;
+       pkcs11_manager_t *manager;
+       enumerator_t *enumerator, *certs;
+       identification_t *keyid;
+       pkcs11_library_t *p11, *found = NULL;
+       CK_SLOT_ID current;
+       linked_list_t *raw;
+       certificate_t *cert;
+       struct {
+               chunk_t value;
+               chunk_t ckaid;
+       } *entry;
+
+       manager = lib->get(lib, "pkcs11-manager");
+       if (!manager)
+       {
+               return NULL;
+       }
+
+       keyid = identification_create_from_encoding(ID_KEY_ID, keyid_chunk);
+       /* store result in a temporary list, avoid recursive operation */
+       raw = linked_list_create();
+
+       enumerator = manager->create_token_enumerator(manager);
+       while (enumerator->enumerate(enumerator, &p11, &current))
+       {
+               if (*slot != -1 && *slot != current)
+               {
+                       continue;
+               }
+               rv = p11->f->C_OpenSession(current, CKF_SERIAL_SESSION, NULL, NULL,
+                                                                  &session);
+               if (rv != CKR_OK)
+               {
+                       DBG1(DBG_CFG, "opening PKCS#11 session failed: %N",
+                                ck_rv_names, rv);
+                       continue;
+               }
+               certs = p11->create_object_enumerator(p11, session, tmpl, countof(tmpl),
+                                                                                         attr, countof(attr));
+               while (certs->enumerate(certs, &object))
+               {
+                       INIT(entry,
+                               .value = chunk_clone(
+                                                       chunk_create(attr[0].pValue, attr[0].ulValueLen)),
+                               .ckaid = chunk_clone(
+                                                       chunk_create(attr[1].pValue, attr[1].ulValueLen)),
+                       );
+                       raw->insert_last(raw, entry);
+               }
+               certs->destroy(certs);
+
+               while (raw->remove_first(raw, (void**)&entry) == SUCCESS)
+               {
+                       if (!found)
+                       {
+                               cert = lib->creds->create(lib->creds, CRED_CERTIFICATE,
+                                                                                 CERT_X509, BUILD_BLOB_ASN1_DER,
+                                                                                 entry->value, BUILD_END);
+                               if (cert)
+                               {
+                                       if (cert->has_subject(cert, keyid))
+                                       {
+                                               DBG1(DBG_CFG, "found cert with keyid '%#B' on PKCS#11 "
+                                                        "token '%s':%d", &keyid_chunk, p11->get_name(p11),
+                                                        current);
+                                               found = p11;
+                                               *ckaid = chunk_clone(entry->ckaid);
+                                               *slot = current;
+                                       }
+                                       cert->destroy(cert);
+                               }
+                               else
+                               {
+                                       DBG1(DBG_CFG, "parsing cert with CKA_ID '%#B' on PKCS#11 "
+                                                "token '%s':%d failed", &entry->ckaid,
+                                                p11->get_name(p11), current);
+                               }
+                       }
+                       chunk_free(&entry->value);
+                       chunk_free(&entry->ckaid);
+                       free(entry);
+               }
+               p11->f->C_CloseSession(session);
+               if (found)
+               {
+                       break;
+               }
+       }
+       enumerator->destroy(enumerator);
+       keyid->destroy(keyid);
+       raw->destroy(raw);
+       return found;
+}
+
 /**
  * Find the key on the token
  */
@@ -404,13 +695,11 @@ static bool find_key(private_pkcs11_private_key_t *this, chunk_t keyid)
        CK_BBOOL reauth = FALSE;
        CK_ATTRIBUTE attr[] = {
                {CKA_KEY_TYPE, &type, sizeof(type)},
-               {CKA_MODULUS, NULL, 0},
-               {CKA_PUBLIC_EXPONENT, NULL, 0},
                {CKA_ALWAYS_AUTHENTICATE, &reauth, sizeof(reauth)},
        };
        enumerator_t *enumerator;
-       chunk_t modulus, pubexp;
        int count = countof(attr);
+       bool found = FALSE;
 
        /* do not use CKA_ALWAYS_AUTHENTICATE if not supported */
        if (!(this->lib->get_features(this->lib) & PKCS11_ALWAYS_AUTH_KEYS))
@@ -421,26 +710,16 @@ static bool find_key(private_pkcs11_private_key_t *this, chunk_t keyid)
                                                        this->session, tmpl, countof(tmpl), attr, count);
        if (enumerator->enumerate(enumerator, &object))
        {
+               this->type = KEY_RSA;
                switch (type)
                {
+                       case CKK_ECDSA:
+                               this->type = KEY_ECDSA;
+                               /* fall-through */
                        case CKK_RSA:
-                               if (attr[1].ulValueLen == -1 || attr[2].ulValueLen == -1)
-                               {
-                                       DBG1(DBG_CFG, "reading modulus/exponent from PKCS#1 failed");
-                                       break;
-                               }
-                               modulus = chunk_create(attr[1].pValue, attr[1].ulValueLen);
-                               pubexp = chunk_create(attr[2].pValue, attr[2].ulValueLen);
-                               this->pubkey = lib->creds->create(lib->creds, CRED_PUBLIC_KEY,
-                                                                       KEY_RSA, BUILD_RSA_MODULUS, modulus,
-                                                                       BUILD_RSA_PUB_EXP, pubexp, BUILD_END);
-                               if (!this->pubkey)
-                               {
-                                       DBG1(DBG_CFG, "extracting public key from PKCS#11 RSA "
-                                                "private key failed");
-                               }
                                this->reauth = reauth;
                                this->object = object;
+                               found = TRUE;
                                break;
                        default:
                                DBG1(DBG_CFG, "PKCS#11 key type %d not supported", type);
@@ -448,7 +727,7 @@ static bool find_key(private_pkcs11_private_key_t *this, chunk_t keyid)
                }
        }
        enumerator->destroy(enumerator);
-       return this->pubkey != NULL;
+       return found;
 }
 
 /**
@@ -500,6 +779,50 @@ static bool login(private_pkcs11_private_key_t *this, int slot)
        return success;
 }
 
+/**
+ * Get a public key from a certificate with a given key ID.
+ */
+static public_key_t* find_pubkey_in_certs(private_pkcs11_private_key_t *this,
+                                                                                 chunk_t keyid)
+{
+       CK_OBJECT_CLASS class = CKO_CERTIFICATE;
+       CK_CERTIFICATE_TYPE type = CKC_X_509;
+       CK_ATTRIBUTE tmpl[] = {
+               {CKA_CLASS, &class, sizeof(class)},
+               {CKA_CERTIFICATE_TYPE, &type, sizeof(type)},
+               {CKA_ID, keyid.ptr, keyid.len},
+       };
+       CK_OBJECT_HANDLE object;
+       CK_ATTRIBUTE attr[] = {
+               {CKA_VALUE, NULL, 0},
+       };
+       enumerator_t *enumerator;
+       chunk_t data = chunk_empty;
+       public_key_t *key = NULL;
+       certificate_t *cert;
+
+       enumerator = this->lib->create_object_enumerator(this->lib, this->session,
+                                                                       tmpl, countof(tmpl), attr, countof(attr));
+       if (enumerator->enumerate(enumerator, &object))
+       {
+               data = chunk_clone(chunk_create(attr[0].pValue, attr[0].ulValueLen));
+       }
+       enumerator->destroy(enumerator);
+
+       if (data.ptr)
+       {
+               cert = lib->creds->create(lib->creds, CRED_CERTIFICATE, CERT_X509,
+                                                                 BUILD_BLOB_ASN1_DER, data, BUILD_END);
+               free(data.ptr);
+               if (cert)
+               {
+                       key = cert->get_public_key(cert);
+                       cert->destroy(cert);
+               }
+       }
+       return key;
+}
+
 /**
  * See header.
  */
@@ -507,7 +830,7 @@ pkcs11_private_key_t *pkcs11_private_key_connect(key_type_t type, va_list args)
 {
        private_pkcs11_private_key_t *this;
        char *module = NULL;
-       chunk_t keyid = chunk_empty;
+       chunk_t keyid = chunk_empty, ckaid = chunk_empty;
        int slot = -1;
        CK_RV rv;
 
@@ -568,7 +891,15 @@ pkcs11_private_key_t *pkcs11_private_key_connect(key_type_t type, va_list args)
        }
        else
        {
-               this->lib = find_lib_by_keyid(keyid, &slot);
+               this->lib = find_lib_by_keyid(keyid, &slot, CKO_PUBLIC_KEY);
+               if (!this->lib)
+               {
+                       this->lib = find_lib_by_keyid(keyid, &slot, CKO_CERTIFICATE);
+               }
+               if (!this->lib)
+               {
+                       this->lib = find_lib_and_keyid_by_skid(keyid, &ckaid, &slot);
+               }
                if (!this->lib)
                {
                        DBG1(DBG_CFG, "no PKCS#11 module found having a keyid %#B", &keyid);
@@ -587,7 +918,7 @@ pkcs11_private_key_t *pkcs11_private_key_connect(key_type_t type, va_list args)
                return NULL;
        }
 
-       this->mutex = mutex_create(MUTEX_TYPE_DEFAULT);
+       this->slot = slot;
        this->keyid = identification_create_from_encoding(ID_KEY_ID, keyid);
 
        if (!login(this, slot))
@@ -596,11 +927,33 @@ pkcs11_private_key_t *pkcs11_private_key_connect(key_type_t type, va_list args)
                return NULL;
        }
 
+       if (ckaid.ptr)
+       {
+               DBG1(DBG_CFG, "using CKA_ID '%#B' for key with keyid '%#B'",
+                        &ckaid, &keyid);
+               keyid = ckaid;
+       }
+
        if (!find_key(this, keyid))
        {
+               DBG1(DBG_CFG, "did not find the key with %s '%#B'",
+                        ckaid.ptr ? "CKA_ID" : "keyid", &keyid);
                destroy(this);
                return NULL;
        }
 
+       this->pubkey = pkcs11_public_key_connect(this->lib, slot, this->type, keyid);
+       if (!this->pubkey)
+       {
+               this->pubkey = find_pubkey_in_certs(this, keyid);
+               if (!this->pubkey)
+               {
+                       DBG1(DBG_CFG, "no public key or certificate found for private key "
+                                "(%s '%#B') on '%s':%d", ckaid.ptr ? "CKA_ID" : "keyid",
+                                &keyid, module, slot);
+                       destroy(this);
+                       return NULL;
+               }
+       }
        return &this->public;
 }