#include <debug.h>
#include <library.h>
#include <processing/jobs/callback_job.h>
+#include <threading/condvar.h>
#include <threading/mutex.h>
#include <utils/hashtable.h>
#include <utils/linked_list.h>
rng_t *rng;
};
+/**
+ * Struct to keep track of locked IPsec SAs
+ */
+typedef struct {
+
+ /**
+ * IPsec SA
+ */
+ ipsec_sa_t *sa;
+
+ /**
+ * Set if this SA is currently in use by a thread
+ */
+ bool locked;
+
+ /**
+ * Condvar used by threads to wait for this entry
+ */
+ condvar_t *condvar;
+
+ /**
+ * Number of threads waiting for this entry
+ */
+ u_int waiting_threads;
+
+ /**
+ * Set if this entry is awaiting deletion
+ */
+ bool awaits_deletion;
+
+} ipsec_sa_entry_t;
+
/**
* Helper struct for expiration events
*/
private_ipsec_sa_mgr_t *manager;
/**
- * SA that expired
+ * Entry that expired
*/
- ipsec_sa_t *sa;
+ ipsec_sa_entry_t *entry;
/**
* 0 if this is a hard expire, otherwise the offset in s (soft->hard)
return chunk_hash(chunk_from_thing(*spi));
}
+/**
+ * Create an SA entry
+ */
+static ipsec_sa_entry_t *create_entry(ipsec_sa_t *sa)
+{
+ ipsec_sa_entry_t *this;
+
+ INIT(this,
+ .condvar = condvar_create(CONDVAR_TYPE_DEFAULT),
+ .sa = sa,
+ );
+ return this;
+}
+
+/**
+ * Destroy an SA entry
+ */
+static void destroy_entry(ipsec_sa_entry_t *entry)
+{
+ entry->condvar->destroy(entry->condvar);
+ entry->sa->destroy(entry->sa);
+ free(entry);
+}
+
+/**
+ * Makes sure an entry is safe to remove
+ * Must be called with this->mutex held.
+ *
+ * @return TRUE if entry can be removed, FALSE if entry is already
+* being removed by another thread
+ */
+static bool wait_remove_entry(private_ipsec_sa_mgr_t *this,
+ ipsec_sa_entry_t *entry)
+{
+ if (entry->awaits_deletion)
+ {
+ /* this will be deleted by another thread already */
+ return FALSE;
+ }
+ entry->awaits_deletion = TRUE;
+ while (entry->locked)
+ {
+ entry->condvar->wait(entry->condvar, this->mutex);
+ }
+ while (entry->waiting_threads > 0)
+ {
+ entry->condvar->broadcast(entry->condvar);
+ entry->condvar->wait(entry->condvar, this->mutex);
+ }
+ return TRUE;
+}
+
+/**
+ * Waits until an is available and then locks it.
+ * Must only be called with this->mutex held
+ */
+static bool wait_for_entry(private_ipsec_sa_mgr_t *this,
+ ipsec_sa_entry_t *entry)
+{
+ while (entry->locked && !entry->awaits_deletion)
+ {
+ entry->waiting_threads++;
+ entry->condvar->wait(entry->condvar, this->mutex);
+ entry->waiting_threads--;
+ }
+ if (entry->awaits_deletion)
+ {
+ /* others may still be waiting, */
+ entry->condvar->signal(entry->condvar);
+ return FALSE;
+ }
+ entry->locked = TRUE;
+ return TRUE;
+}
+
/**
* Flushes all entries
* Must be called with this->mutex held.
*/
static void flush_entries(private_ipsec_sa_mgr_t *this)
{
+ ipsec_sa_entry_t *current;
enumerator_t *enumerator;
- ipsec_sa_t *current;
DBG2(DBG_ESP, "flushing SAD");
enumerator = this->sas->create_enumerator(this->sas);
while (enumerator->enumerate(enumerator, (void**)¤t))
{
- this->sas->remove_at(this->sas, enumerator);
- current->destroy(current);
+ if (wait_remove_entry(this, current))
+ {
+ this->sas->remove_at(this->sas, enumerator);
+ destroy_entry(current);
+ }
}
enumerator->destroy(enumerator);
}
/*
* Different match functions to find SAs in the linked list
*/
-static bool match_entry_by_ptr(ipsec_sa_t *sa, ipsec_sa_t *other)
+static bool match_entry_by_ptr(ipsec_sa_entry_t *item, ipsec_sa_entry_t *entry)
+{
+ return item == entry;
+}
+
+static bool match_entry_by_sa_ptr(ipsec_sa_entry_t *item, ipsec_sa_t *sa)
{
- return sa == other;
+ return item->sa == sa;
}
-static bool match_entry_by_spi_inbound(ipsec_sa_t *sa, u_int32_t spi,
+static bool match_entry_by_spi_inbound(ipsec_sa_entry_t *item, u_int32_t spi,
bool inbound)
{
- return sa->get_spi(sa) == spi && sa->is_inbound(sa) == inbound;
+ return item->sa->get_spi(item->sa) == spi &&
+ item->sa->is_inbound(item->sa) == inbound;
}
-static bool match_entry_by_spi_src_dst(ipsec_sa_t *sa, u_int32_t spi,
+static bool match_entry_by_spi_src_dst(ipsec_sa_entry_t *item, u_int32_t spi,
host_t *src, host_t *dst)
{
- return sa->match_by_spi_src_dst(sa, spi, src, dst);
+ return item->sa->match_by_spi_src_dst(item->sa, spi, src, dst);
+}
+
+static bool match_entry_by_reqid_inbound(ipsec_sa_entry_t *item,
+ u_int32_t reqid, bool inbound)
+{
+ return item->sa->match_by_reqid(item->sa, reqid, inbound);
+}
+
+static bool match_entry_by_spi_dst(ipsec_sa_entry_t *item, u_int32_t spi,
+ host_t *dst)
+{
+ return item->sa->match_by_spi_dst(item->sa, spi, dst);
+}
+
+/**
+ * Remove an entry
+ */
+static bool remove_entry(private_ipsec_sa_mgr_t *this, ipsec_sa_entry_t *entry)
+{
+ ipsec_sa_entry_t *current;
+ enumerator_t *enumerator;
+ bool removed = FALSE;
+
+ enumerator = this->sas->create_enumerator(this->sas);
+ while (enumerator->enumerate(enumerator, (void**)¤t))
+ {
+ if (current == entry)
+ {
+ if (wait_remove_entry(this, current))
+ {
+ this->sas->remove_at(this->sas, enumerator);
+ removed = TRUE;
+ }
+ break;
+ }
+ }
+ enumerator->destroy(enumerator);
+ return removed;
}
/**
this->mutex->lock(this->mutex);
if (this->sas->find_first(this->sas, (void*)match_entry_by_ptr,
- NULL, expired->sa) == SUCCESS)
+ NULL, expired->entry) == SUCCESS)
{
u_int32_t hard_offset = expired->hard_offset;
- ipsec_sa_t *sa = expired->sa;
+ ipsec_sa_t *sa = expired->entry->sa;
ipsec->events->expire(ipsec->events, sa->get_reqid(sa),
sa->get_protocol(sa), sa->get_spi(sa),
return JOB_RESCHEDULE(hard_offset);
}
/* hard limit reached */
- this->sas->remove(this->sas, sa, NULL);
- sa->destroy(sa);
+ if (remove_entry(this, expired->entry))
+ {
+ destroy_entry(expired->entry);
+ }
}
this->mutex->unlock(this->mutex);
return JOB_REQUEUE_NONE;
* Schedule a job to handle IPsec SA expiration
*/
static void schedule_expiration(private_ipsec_sa_mgr_t *this,
- ipsec_sa_t *sa)
+ ipsec_sa_entry_t *entry)
{
- lifetime_cfg_t *lifetime = sa->get_lifetime(sa);
+ lifetime_cfg_t *lifetime = entry->sa->get_lifetime(entry->sa);
ipsec_sa_expired_t *expired;
callback_job_t *job;
u_int32_t timeout;
INIT(expired,
.manager = this,
- .sa = sa,
+ .entry = entry,
);
/* schedule a rekey first, a hard timeout will be scheduled then, if any */
u_int16_t cpi, bool encap, bool esn, bool inbound,
traffic_selector_t *src_ts, traffic_selector_t *dst_ts)
{
+ ipsec_sa_entry_t *entry;
ipsec_sa_t *sa_new;
DBG2(DBG_ESP, "adding SAD entry with SPI %.8x and reqid {%u}",
return FAILED;
}
- schedule_expiration(this, sa_new);
- this->sas->insert_last(this->sas, sa_new);
+ entry = create_entry(sa_new);
+ schedule_expiration(this, entry);
+ this->sas->insert_last(this->sas, entry);
this->mutex->unlock(this->mutex);
return SUCCESS;
private_ipsec_sa_mgr_t *this, host_t *src, host_t *dst, u_int32_t spi,
u_int8_t protocol, u_int16_t cpi, mark_t mark)
{
- ipsec_sa_t *current, *found = NULL;
+ ipsec_sa_entry_t *current, *found = NULL;
enumerator_t *enumerator;
this->mutex->lock(this->mutex);
{
if (match_entry_by_spi_src_dst(current, spi, src, dst))
{
- this->sas->remove_at(this->sas, enumerator);
- found = current;
+ if (wait_remove_entry(this, current))
+ {
+ this->sas->remove_at(this->sas, enumerator);
+ found = current;
+ }
break;
}
}
if (found)
{
DBG2(DBG_ESP, "deleted %sbound SAD entry with SPI %.8x",
- found->is_inbound(found) ? "in" : "out", ntohl(spi));
- found->destroy(found);
+ found->sa->is_inbound(found->sa) ? "in" : "out", ntohl(spi));
+ destroy_entry(found);
return SUCCESS;
}
return FAILED;
}
+METHOD(ipsec_sa_mgr_t, checkout_by_reqid, ipsec_sa_t*,
+ private_ipsec_sa_mgr_t *this, u_int32_t reqid, bool inbound)
+{
+ ipsec_sa_entry_t *entry;
+ ipsec_sa_t *sa = NULL;
+
+ this->mutex->lock(this->mutex);
+ if (this->sas->find_first(this->sas, (void*)match_entry_by_reqid_inbound,
+ (void**)&entry, reqid, inbound) == SUCCESS &&
+ wait_for_entry(this, entry))
+ {
+ sa = entry->sa;
+ }
+ this->mutex->unlock(this->mutex);
+ return sa;
+}
+
+METHOD(ipsec_sa_mgr_t, checkout_by_spi, ipsec_sa_t*,
+ private_ipsec_sa_mgr_t *this, u_int32_t spi, host_t *dst)
+{
+ ipsec_sa_entry_t *entry;
+ ipsec_sa_t *sa = NULL;
+
+ this->mutex->lock(this->mutex);
+ if (this->sas->find_first(this->sas, (void*)match_entry_by_spi_dst,
+ (void**)&entry, spi, dst) == SUCCESS &&
+ wait_for_entry(this, entry))
+ {
+ sa = entry->sa;
+ }
+ this->mutex->unlock(this->mutex);
+ return sa;
+}
+
+METHOD(ipsec_sa_mgr_t, checkin, void,
+ private_ipsec_sa_mgr_t *this, ipsec_sa_t *sa)
+{
+ ipsec_sa_entry_t *entry;
+
+ this->mutex->lock(this->mutex);
+ if (this->sas->find_first(this->sas, (void*)match_entry_by_sa_ptr,
+ (void**)&entry, sa) == SUCCESS)
+ {
+ if (entry->locked)
+ {
+ entry->locked = FALSE;
+ entry->condvar->signal(entry->condvar);
+ }
+ }
+ this->mutex->unlock(this->mutex);
+}
+
METHOD(ipsec_sa_mgr_t, flush_sas, status_t,
private_ipsec_sa_mgr_t *this)
{
.get_spi = _get_spi,
.add_sa = _add_sa,
.del_sa = _del_sa,
+ .checkout_by_spi = _checkout_by_spi,
+ .checkout_by_reqid = _checkout_by_reqid,
+ .checkin = _checkin,
.flush_sas = _flush_sas,
.destroy = _destroy,
},