]> git.ipfire.org Git - people/ms/strongswan.git/blobdiff - src/libstrongswan/credentials/credential_manager.c
Merge branch 'ikev1-clean' into ikev1-master
[people/ms/strongswan.git] / src / libstrongswan / credentials / credential_manager.c
index b3461b810de0c3633fa677795e52b2c3e4a6e502..d54359ebf93a7a3ce8fc66fdfe91682b43dfe90c 100644 (file)
@@ -52,6 +52,11 @@ struct private_credential_manager_t {
         */
        thread_value_t *local_sets;
 
+       /**
+        * Exclusive local sets, linked_list_t with credential_set_t
+        */
+       thread_value_t *exclusive_local_sets;
+
        /**
         * trust relationship and certificate cache
         */
@@ -117,12 +122,23 @@ typedef struct {
        enumerator_t *global;
        /** enumerator over local sets */
        enumerator_t *local;
+       /** enumerator over exclusive local sets */
+       enumerator_t *exclusive;
 } sets_enumerator_t;
 
 
 METHOD(enumerator_t, sets_enumerate, bool,
        sets_enumerator_t *this, credential_set_t **set)
 {
+       if (this->exclusive)
+       {
+               if (this->exclusive->enumerate(this->exclusive, set))
+               {       /* only enumerate last added */
+                       this->exclusive->destroy(this->exclusive);
+                       this->exclusive = NULL;
+                       return TRUE;
+               }
+       }
        if (this->global)
        {
                if (this->global->enumerate(this->global, set))
@@ -145,6 +161,7 @@ METHOD(enumerator_t, sets_destroy, void,
 {
        DESTROY_IF(this->global);
        DESTROY_IF(this->local);
+       DESTROY_IF(this->exclusive);
        free(this);
 }
 
@@ -154,19 +171,28 @@ METHOD(enumerator_t, sets_destroy, void,
 static enumerator_t *create_sets_enumerator(private_credential_manager_t *this)
 {
        sets_enumerator_t *enumerator;
-       linked_list_t *local;
+       linked_list_t *list;
 
        INIT(enumerator,
                .public = {
                        .enumerate = (void*)_sets_enumerate,
                        .destroy = _sets_destroy,
                },
-               .global = this->sets->create_enumerator(this->sets),
        );
-       local = this->local_sets->get(this->local_sets);
-       if (local)
+
+       list = this->exclusive_local_sets->get(this->exclusive_local_sets);
+       if (list && list->get_count(list))
+       {
+               enumerator->exclusive = list->create_enumerator(list);
+       }
+       else
        {
-               enumerator->local = local->create_enumerator(local);
+               enumerator->global = this->sets->create_enumerator(this->sets);
+               list = this->local_sets->get(this->local_sets);
+               if (list)
+               {
+                       enumerator->local = list->create_enumerator(list);
+               }
        }
        return &enumerator->public;
 }
@@ -373,26 +399,55 @@ METHOD(credential_manager_t, get_shared, shared_key_t*,
 }
 
 METHOD(credential_manager_t, add_local_set, void,
-       private_credential_manager_t *this, credential_set_t *set)
+       private_credential_manager_t *this, credential_set_t *set, bool exclusive)
 {
        linked_list_t *sets;
+       thread_value_t *tv;
 
-       sets = this->local_sets->get(this->local_sets);
+       if (exclusive)
+       {
+               tv = this->exclusive_local_sets;
+       }
+       else
+       {
+               tv = this->local_sets;
+       }
+       sets = tv->get(tv);
        if (!sets)
-       {       /* first invocation */
+       {
                sets = linked_list_create();
-               this->local_sets->set(this->local_sets, sets);
+               tv->set(tv, sets);
+       }
+       if (exclusive)
+       {
+               sets->insert_first(sets, set);
+       }
+       else
+       {
+               sets->insert_last(sets, set);
        }
-       sets->insert_last(sets, set);
 }
 
 METHOD(credential_manager_t, remove_local_set, void,
        private_credential_manager_t *this, credential_set_t *set)
 {
        linked_list_t *sets;
+       thread_value_t *tv;
 
-       sets = this->local_sets->get(this->local_sets);
-       sets->remove(sets, set, NULL);
+       tv = this->local_sets;
+       sets = tv->get(tv);
+       if (sets && sets->remove(sets, set, NULL) && sets->get_count(sets) == 0)
+       {
+               tv->set(tv, NULL);
+               sets->destroy(sets);
+       }
+       tv = this->exclusive_local_sets;
+       sets = tv->get(tv);
+       if (sets && sets->remove(sets, set, NULL) && sets->get_count(sets) == 0)
+       {
+               tv->set(tv, NULL);
+               sets->destroy(sets);
+       }
 }
 
 METHOD(credential_manager_t, cache_cert, void,
@@ -859,7 +914,7 @@ METHOD(credential_manager_t, create_public_enumerator, enumerator_t*,
        if (auth)
        {
                enumerator->wrapper = auth_cfg_wrapper_create(auth);
-               add_local_set(this, &enumerator->wrapper->set);
+               add_local_set(this, &enumerator->wrapper->set, FALSE);
        }
        this->lock->read_lock(this->lock);
        return &enumerator->public;
@@ -992,42 +1047,45 @@ METHOD(credential_manager_t, get_private, private_key_t*,
                }
        }
 
-       /* if a specific certificate is preferred, check for a matching key */
-       cert = auth->get(auth, AUTH_RULE_SUBJECT_CERT);
-       if (cert)
+       if (auth)
        {
-               private = get_private_by_cert(this, cert, type);
-               if (private)
+               /* if a specific certificate is preferred, check for a matching key */
+               cert = auth->get(auth, AUTH_RULE_SUBJECT_CERT);
+               if (cert)
                {
-                       trustchain = build_trustchain(this, cert, auth);
-                       if (trustchain)
+                       private = get_private_by_cert(this, cert, type);
+                       if (private)
                        {
-                               auth->merge(auth, trustchain, FALSE);
-                               trustchain->destroy(trustchain);
+                               trustchain = build_trustchain(this, cert, auth);
+                               if (trustchain)
+                               {
+                                       auth->merge(auth, trustchain, FALSE);
+                                       trustchain->destroy(trustchain);
+                               }
+                               return private;
                        }
-                       return private;
                }
-       }
 
-       /* try to build a trust chain for each certificate found */
-       enumerator = create_cert_enumerator(this, CERT_ANY, type, id, FALSE);
-       while (enumerator->enumerate(enumerator, &cert))
-       {
-               private = get_private_by_cert(this, cert, type);
-               if (private)
+               /* try to build a trust chain for each certificate found */
+               enumerator = create_cert_enumerator(this, CERT_ANY, type, id, FALSE);
+               while (enumerator->enumerate(enumerator, &cert))
                {
-                       trustchain = build_trustchain(this, cert, auth);
-                       if (trustchain)
+                       private = get_private_by_cert(this, cert, type);
+                       if (private)
                        {
-                               auth->merge(auth, trustchain, FALSE);
-                               trustchain->destroy(trustchain);
-                               break;
+                               trustchain = build_trustchain(this, cert, auth);
+                               if (trustchain)
+                               {
+                                       auth->merge(auth, trustchain, FALSE);
+                                       trustchain->destroy(trustchain);
+                                       break;
+                               }
+                               private->destroy(private);
+                               private = NULL;
                        }
-                       private->destroy(private);
-                       private = NULL;
                }
+               enumerator->destroy(enumerator);
        }
-       enumerator->destroy(enumerator);
 
        /* if no valid trustchain was found, fall back to the first usable cert */
        if (!private)
@@ -1038,7 +1096,10 @@ METHOD(credential_manager_t, get_private, private_key_t*,
                        private = get_private_by_cert(this, cert, type);
                        if (private)
                        {
-                               auth->add(auth, AUTH_RULE_SUBJECT_CERT, cert->get_ref(cert));
+                               if (auth)
+                               {
+                                       auth->add(auth, AUTH_RULE_SUBJECT_CERT, cert->get_ref(cert));
+                               }
                                break;
                        }
                }
@@ -1100,6 +1161,7 @@ METHOD(credential_manager_t, destroy, void,
        this->sets->remove(this->sets, this->cache, NULL);
        this->sets->destroy(this->sets);
        this->local_sets->destroy(this->local_sets);
+       this->exclusive_local_sets->destroy(this->exclusive_local_sets);
        this->cache->destroy(this->cache);
        this->validators->destroy(this->validators);
        this->lock->destroy(this->lock);
@@ -1144,6 +1206,7 @@ credential_manager_t *credential_manager_create()
        );
 
        this->local_sets = thread_value_create((thread_cleanup_t)this->sets->destroy);
+       this->exclusive_local_sets = thread_value_create((thread_cleanup_t)this->sets->destroy);
        this->sets->insert_first(this->sets, this->cache);
 
        return &this->public;