]> git.ipfire.org Git - thirdparty/strongswan.git/commitdiff
hashtable: Optionally sort keys/items in buckets in a specific way
authorTobias Brunner <tobias@strongswan.org>
Fri, 24 Apr 2020 06:50:24 +0000 (08:50 +0200)
committerTobias Brunner <tobias@strongswan.org>
Mon, 20 Jul 2020 11:50:11 +0000 (13:50 +0200)
This can improve negative lookups, but is mostly intended to be used
with get_match() so keys/items can be matched/enumerated in a specific
order.  It's like storing sorted linked lists under a shared key but
with less memory overhead.

src/libstrongswan/collections/hashtable.c
src/libstrongswan/collections/hashtable.h
src/libstrongswan/tests/suites/test_hashtable.c

index 4607fad39d79d9c986e756e01195f2c8310877f0..28f2bef8d6a476fba5b9ceb9e7bd6bf4952cc3c3 100644 (file)
 
 #include <utils/chunk.h>
 
-/** The minimum capacity of the hash table (MUST be a power of 2) */
-#define MIN_CAPACITY 8
-/** The maximum capacity of the hash table (MUST be a power of 2) */
-#define MAX_CAPACITY (1 << 30)
+/** The minimum size of the hash table (MUST be a power of 2) */
+#define MIN_SIZE 8
+/** The maximum size of the hash table (MUST be a power of 2) */
+#define MAX_SIZE (1 << 30)
 /** Maximum load factor before the hash table is resized */
 #define LOAD_FACTOR 0.75f
 
@@ -72,7 +72,6 @@ typedef struct private_hashtable_t private_hashtable_t;
 
 /**
  * Private data of a hashtable_t object.
- *
  */
 struct private_hashtable_t {
 
@@ -87,12 +86,12 @@ struct private_hashtable_t {
        u_int count;
 
        /**
-        * The current capacity of the hash table (always a power of 2).
+        * The current size of the hash table (always a power of 2).
         */
-       u_int capacity;
+       u_int size;
 
        /**
-        * The current mask to calculate the row index (capacity - 1).
+        * The current mask to calculate the row index (size - 1).
         */
        u_int mask;
 
@@ -110,6 +109,11 @@ struct private_hashtable_t {
         * The equality function.
         */
        hashtable_equals_t equals;
+
+       /**
+        * Alternative comparison function.
+        */
+       hashtable_cmp_t cmp;
 };
 
 typedef struct private_enumerator_t private_enumerator_t;
@@ -203,13 +207,13 @@ static u_int get_nearest_powerof2(u_int n)
 /**
  * Init hash table parameters
  */
-static void init_hashtable(private_hashtable_t *this, u_int capacity)
+static void init_hashtable(private_hashtable_t *this, u_int size)
 {
-       capacity = max(MIN_CAPACITY, min(capacity, MAX_CAPACITY));
-       this->capacity = get_nearest_powerof2(capacity);
-       this->mask = this->capacity - 1;
+       size = max(MIN_SIZE, min(size, MAX_SIZE));
+       this->size = get_nearest_powerof2(size);
+       this->mask = this->size - 1;
 
-       this->table = calloc(this->capacity, sizeof(pair_t*));
+       this->table = calloc(this->size, sizeof(pair_t*));
 }
 
 /**
@@ -218,39 +222,48 @@ static void init_hashtable(private_hashtable_t *this, u_int capacity)
 static void rehash(private_hashtable_t *this)
 {
        pair_t **old_table, *to_move, *pair, *next;
-       u_int row, new_row, old_capacity;
+       u_int row, new_row, old_size;
 
-       if (this->capacity >= MAX_CAPACITY)
+       if (this->size >= MAX_SIZE)
        {
                return;
        }
 
-       old_capacity = this->capacity;
+       old_size = this->size;
        old_table = this->table;
 
-       init_hashtable(this, old_capacity << 1);
+       init_hashtable(this, old_size << 1);
 
-       for (row = 0; row < old_capacity; row++)
+       for (row = 0; row < old_size; row++)
        {
                to_move = old_table[row];
                while (to_move)
                {
+                       pair_t *prev = NULL;
+
                        new_row = to_move->hash & this->mask;
                        pair = this->table[new_row];
-                       if (pair)
+                       while (pair)
                        {
-                               while (pair->next)
+                               if (this->cmp && this->cmp(to_move->key, pair->key) < 0)
                                {
-                                       pair = pair->next;
+                                       break;
                                }
-                               pair->next = to_move;
+                               prev = pair;
+                               pair = pair->next;
+                       }
+                       next = to_move->next;
+                       to_move->next = NULL;
+                       if (prev)
+                       {
+                               to_move->next = prev->next;
+                               prev->next = to_move;
                        }
                        else
                        {
+                               to_move->next = this->table[new_row];
                                this->table[new_row] = to_move;
                        }
-                       next = to_move->next;
-                       to_move->next = NULL;
                        to_move = next;
                }
        }
@@ -266,6 +279,7 @@ static inline pair_t *find_key(private_hashtable_t *this, const void *key,
                                                           pair_t **out_prev)
 {
        pair_t *pair, *prev = NULL;
+       bool use_callback = equals != NULL;
        u_int hash;
 
        if (!this->count && !out_hash)
@@ -273,6 +287,7 @@ static inline pair_t *find_key(private_hashtable_t *this, const void *key,
                return NULL;
        }
 
+       equals = equals ?: this->equals;
        hash = this->hash(key);
        if (out_hash)
        {
@@ -282,7 +297,23 @@ static inline pair_t *find_key(private_hashtable_t *this, const void *key,
        pair = this->table[hash & this->mask];
        while (pair)
        {
-               if (hash == pair->hash && equals(key, pair->key))
+               /* when keys are ordered, we compare all items so we can abort earlier
+                * even if the hash does not match, but only as long as we don't
+                * have a callback */
+               if (!use_callback && this->cmp)
+               {
+                       int cmp = this->cmp(key, pair->key);
+                       if (cmp == 0)
+                       {
+                               break;
+                       }
+                       else if (cmp < 0)
+                       {       /* no need to continue as the key we search is smaller */
+                               pair = NULL;
+                               break;
+                       }
+               }
+               else if (hash == pair->hash && equals(key, pair->key))
                {
                        break;
                }
@@ -303,12 +334,12 @@ METHOD(hashtable_t, put, void*,
        pair_t *pair, *prev = NULL;
        u_int hash;
 
-       if (this->count >= this->capacity * LOAD_FACTOR)
+       if (this->count >= this->size * LOAD_FACTOR)
        {
                rehash(this);
        }
 
-       pair = find_key(this, key, this->equals, &hash, &prev);
+       pair = find_key(this, key, NULL, &hash, &prev);
        if (pair)
        {
                old_value = pair->value;
@@ -320,12 +351,13 @@ METHOD(hashtable_t, put, void*,
                pair = pair_create(key, value, hash);
                if (prev)
                {
+                       pair->next = prev->next;
                        prev->next = pair;
                }
                else
                {
+                       pair->next = this->table[hash & this->mask];
                        this->table[hash & this->mask] = pair;
-
                }
                this->count++;
        }
@@ -336,7 +368,7 @@ METHOD(hashtable_t, put, void*,
 METHOD(hashtable_t, get, void*,
        private_hashtable_t *this, const void *key)
 {
-       pair_t *pair = find_key(this, key, this->equals, NULL, NULL);
+       pair_t *pair = find_key(this, key, NULL, NULL, NULL);
        return pair ? pair->value : NULL;
 }
 
@@ -353,7 +385,7 @@ METHOD(hashtable_t, remove_, void*,
        void *value = NULL;
        pair_t *pair, *prev = NULL;
 
-       pair = find_key(this, key, this->equals, NULL, &prev);
+       pair = find_key(this, key, NULL, NULL, &prev);
        if (pair)
        {
                if (prev)
@@ -405,7 +437,7 @@ METHOD(enumerator_t, enumerate, bool,
 
        VA_ARGS_VGET(args, key, value);
 
-       while (this->count && this->row < this->table->capacity)
+       while (this->count && this->row < this->table->size)
        {
                this->prev = this->current;
                if (this->current)
@@ -458,7 +490,7 @@ static void destroy_internal(private_hashtable_t *this,
        pair_t *pair, *next;
        u_int row;
 
-       for (row = 0; row < this->capacity; row++)
+       for (row = 0; row < this->size; row++)
        {
                pair = this->table[row];
                while (pair)
@@ -488,11 +520,11 @@ METHOD(hashtable_t, destroy_function, void,
        destroy_internal(this, fn);
 }
 
-/*
- * Described in header.
+/**
+ * Create a hash table
  */
-hashtable_t *hashtable_create(hashtable_hash_t hash, hashtable_equals_t equals,
-                                                         u_int capacity)
+static private_hashtable_t *hashtable_create_internal(hashtable_hash_t hash,
+                                                                                                         u_int size)
 {
        private_hashtable_t *this;
 
@@ -509,10 +541,35 @@ hashtable_t *hashtable_create(hashtable_hash_t hash, hashtable_equals_t equals,
                        .destroy_function = _destroy_function,
                },
                .hash = hash,
-               .equals = equals,
        );
 
-       init_hashtable(this, capacity);
+       init_hashtable(this, size);
+
+       return this;
+}
+
+/*
+ * Described in header
+ */
+hashtable_t *hashtable_create(hashtable_hash_t hash, hashtable_equals_t equals,
+                                                         u_int size)
+{
+       private_hashtable_t *this = hashtable_create_internal(hash, size);
+
+       this->equals = equals;
+
+       return &this->public;
+}
+
+/*
+ * Described in header
+ */
+hashtable_t *hashtable_create_sorted(hashtable_hash_t hash,
+                                                                        hashtable_cmp_t cmp, u_int size)
+{
+       private_hashtable_t *this = hashtable_create_internal(hash, size);
+
+       this->cmp = cmp;
 
        return &this->public;
 }
index ab217ea11a68f0d62602b9ffcb60aef5ca6ec13d..e644c461a20b2e6a94cd68fde9f83eb61fb7f00b 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (C) 2008-2012 Tobias Brunner
+ * Copyright (C) 2008-2020 Tobias Brunner
  * HSR Hochschule fuer Technik Rapperswil
  *
  * This program is free software; you can redistribute it and/or modify it
@@ -76,10 +76,23 @@ bool hashtable_equals_ptr(const void *key, const void *other_key);
  */
 bool hashtable_equals_str(const void *key, const void *other_key);
 
+/**
+ * Prototype for a function that compares the two keys in order to sort them.
+ *
+ * @param key                  first key (the one we are looking for/inserting)
+ * @param other_key            second key
+ * @return                             less than, equal to, or greater than 0 if key is
+ *                                             less than, equal to, or greater than other_key
+ */
+typedef int (*hashtable_cmp_t)(const void *key, const void *other_key);
+
 /**
  * Class implementing a hash table.
  *
  * General purpose hash table. This hash table is not synchronized.
+ *
+ * @note Any ordering only pertains to keys/items in the same bucket (with or
+ * without the same hash value), not to the order when enumerating.
  */
 struct hashtable_t {
 
@@ -88,7 +101,7 @@ struct hashtable_t {
         *
         * @return                      enumerator over (void *key, void *value)
         */
-       enumerator_t *(*create_enumerator) (hashtable_t *this);
+       enumerator_t *(*create_enumerator)(hashtable_t *this);
 
        /**
         * Adds the given value with the given key to the hash table, if there
@@ -100,7 +113,7 @@ struct hashtable_t {
         * @param value         the value to store
         * @return                      NULL if no item was replaced, the old value otherwise
         */
-       void *(*put) (hashtable_t *this, const void *key, void *value);
+       void *(*put)(hashtable_t *this, const void *key, void *value);
 
        /**
         * Returns the value with the given key, if the hash table contains such an
@@ -109,24 +122,26 @@ struct hashtable_t {
         * @param key           the key of the requested value
         * @return                      the value, NULL if not found
         */
-       void *(*get) (hashtable_t *this, const void *key);
+       void *(*get)(hashtable_t *this, const void *key);
 
        /**
-        * Returns the value with a matching key, if the hash table contains such an
-        * entry, otherwise NULL is returned.
+        * Returns the first value with a matching key, if the hash table contains
+        * such an entry, otherwise NULL is returned.
         *
         * Compared to get() the given match function is used to compare the keys
-        * for equality.  The hash function does have to be devised properly in
+        * for equality.  The hash function does have to be devised specially in
         * order to make this work if the match function compares keys differently
-        * than the equals function provided to the constructor.  This basically
-        * allows to enumerate all entries with the same hash value.
+        * than the equals/comparison function provided to the constructor.
+        *
+        * This basically allows to enumerate all entries with the same hash value
+        * in their key's order.
         *
         * @param key           the key to match against
         * @param match         match function to be used when comparing keys
         * @return                      the value, NULL if not found
         */
-       void *(*get_match) (hashtable_t *this, const void *key,
-                                               hashtable_equals_t match);
+       void *(*get_match)(hashtable_t *this, const void *key,
+                                          hashtable_equals_t match);
 
        /**
         * Removes the value with the given key from the hash table and returns the
@@ -135,7 +150,7 @@ struct hashtable_t {
         * @param key           the key of the value to remove
         * @return                      the removed value, NULL if not found
         */
-       void *(*remove) (hashtable_t *this, const void *key);
+       void *(*remove)(hashtable_t *this, const void *key);
 
        /**
         * Removes the key and value pair from the hash table at which the given
@@ -143,19 +158,19 @@ struct hashtable_t {
         *
         * @param enumerator    enumerator, from create_enumerator
         */
-       void (*remove_at) (hashtable_t *this, enumerator_t *enumerator);
+       void (*remove_at)(hashtable_t *this, enumerator_t *enumerator);
 
        /**
         * Gets the number of items in the hash table.
         *
         * @return                      number of items
         */
-       u_int (*get_count) (hashtable_t *this);
+       u_int (*get_count)(hashtable_t *this);
 
        /**
         * Destroys a hash table object.
         */
-       void (*destroy) (hashtable_t *this);
+       void (*destroy)(hashtable_t *this);
 
        /**
         * Destroys a hash table object and calls the given function for each
@@ -168,14 +183,27 @@ struct hashtable_t {
 };
 
 /**
- * Creates an empty hash table object.
+ * Creates an empty hash table object. Items in buckets are ordered in
+ * insertion order.
  *
  * @param hash                 hash function
  * @param equals               equals function
- * @param capacity             initial capacity
- * @return                             hashtable_t object.
+ * @param size                 initial size
+ * @return                             hashtable_t object
  */
 hashtable_t *hashtable_create(hashtable_hash_t hash, hashtable_equals_t equals,
-                                                         u_int capacity);
+                                                         u_int size);
+
+/**
+ * Creates an empty hash table object with keys in each bucket sorted according
+ * to the given comparison function.
+ *
+ * @param hash                 hash function
+ * @param cmp                  comparison function
+ * @param size                 initial size
+ * @return                             hashtable_t object.
+ */
+hashtable_t *hashtable_create_sorted(hashtable_hash_t hash,
+                                                                        hashtable_cmp_t cmp, u_int size);
 
 #endif /** HASHTABLE_H_ @}*/
index dfd38e6cf0b904e74427b2a597b39ec1c9c55533..25d85d09d38b7883e5789c935492b0d8d6d8b2fc 100644 (file)
@@ -170,6 +170,43 @@ START_TEST(test_get_match_remove)
 }
 END_TEST
 
+START_TEST(test_get_match_sorted)
+{
+       char *k1 = "key1_a", *k2 = "key2", *k3 = "key1_b", *k4 = "key1_c";
+       char *v1 = "val1", *v2 = "val2", *v3 = "val3", *value;
+
+       ht = hashtable_create_sorted((hashtable_hash_t)hash_match,
+                                                                (hashtable_cmp_t)strcmp, 0);
+
+       ht->put(ht, k3, v3);
+       ht->put(ht, k2, v2);
+       ht->put(ht, k1, v1);
+       ht->put(ht, k4, v1);
+       ht->remove(ht, k1);
+       ht->put(ht, k1, v1);
+       ck_assert_int_eq(ht->get_count(ht), 4);
+       ck_assert(streq(ht->get(ht, k1), v1));
+       ck_assert(streq(ht->get(ht, k2), v2));
+       ck_assert(streq(ht->get(ht, k3), v3));
+       ck_assert(streq(ht->get(ht, k4), v1));
+
+       value = ht->get_match(ht, k1, (hashtable_equals_t)equal_match);
+       ck_assert(value != NULL);
+       ck_assert(streq(value, v1));
+       value = ht->get_match(ht, k2, (hashtable_equals_t)equal_match);
+       ck_assert(value != NULL);
+       ck_assert(streq(value, v2));
+       value = ht->get_match(ht, k3, (hashtable_equals_t)equal_match);
+       ck_assert(value != NULL);
+       ck_assert(streq(value, v1));
+       value = ht->get_match(ht, k4, (hashtable_equals_t)equal_match);
+       ck_assert(value != NULL);
+       ck_assert(streq(value, v1));
+
+       ht->destroy(ht);
+}
+END_TEST
+
 /*******************************************************************************
  * remove
  */
@@ -206,6 +243,8 @@ START_TEST(test_remove)
        char *k1 = "key1", *k2 = "key2", *k3 = "key3";
 
        do_remove(k1, k2, k3);
+       do_remove(k3, k2, k1);
+       do_remove(k1, k3, k2);
 }
 END_TEST
 
@@ -218,6 +257,36 @@ START_TEST(test_remove_one_bucket)
                                                  (hashtable_equals_t)equals, 0);
 
        do_remove(k1, k2, k3);
+       do_remove(k3, k2, k1);
+       do_remove(k1, k3, k2);
+}
+END_TEST
+
+START_TEST(test_remove_sorted)
+{
+       char *k1 = "key1", *k2 = "key2", *k3 = "key3";
+
+       ht->destroy(ht);
+       ht = hashtable_create_sorted((hashtable_hash_t)hash,
+                                                                (hashtable_cmp_t)strcmp, 0);
+
+       do_remove(k1, k2, k3);
+       do_remove(k3, k2, k1);
+       do_remove(k1, k3, k2);
+}
+END_TEST
+
+START_TEST(test_remove_sorted_one_bucket)
+{
+       char *k1 = "key1_a", *k2 = "key1_b", *k3 = "key1_c";
+
+       ht->destroy(ht);
+       ht = hashtable_create_sorted((hashtable_hash_t)hash_match,
+                                                                (hashtable_cmp_t)strcmp, 0);
+
+       do_remove(k1, k2, k3);
+       do_remove(k3, k2, k1);
+       do_remove(k1, k3, k2);
 }
 END_TEST
 
@@ -358,6 +427,11 @@ static bool equals_int(int *key1, int *key2)
        return *key1 == *key2;
 }
 
+static int cmp_int(int *key1, int *key2)
+{
+       return *key1 - *key2;
+}
+
 START_SETUP(setup_ht_many)
 {
        ht = hashtable_create((hashtable_hash_t)hash_int,
@@ -366,6 +440,14 @@ START_SETUP(setup_ht_many)
 }
 END_SETUP
 
+START_SETUP(setup_ht_many_cmp)
+{
+       ht = hashtable_create_sorted((hashtable_hash_t)hash_int,
+                                                                (hashtable_cmp_t)cmp_int, 0);
+       ck_assert_int_eq(ht->get_count(ht), 0);
+}
+END_SETUP
+
 START_TEARDOWN(teardown_ht_many)
 {
        ht->destroy_function(ht, (void*)free);
@@ -377,35 +459,42 @@ START_TEST(test_many_items)
        u_int count = 250000;
        int i, *val, r;
 
+#define GET_VALUE(i) ({ _i == 0 ? i : (count-1-i); })
+
        for (i = 0; i < count; i++)
        {
                val = malloc_thing(int);
-               *val = i;
+               *val = GET_VALUE(i);
                ht->put(ht, val, val);
        }
        for (i = 0; i < count; i++)
        {
-               val = ht->get(ht, &i);
-               ck_assert_int_eq(i, *val);
+               r = GET_VALUE(i);
+               val = ht->get(ht, &r);
+               ck_assert_int_eq(GET_VALUE(i), *val);
        }
+       ck_assert_int_eq(count, ht->get_count(ht));
        for (i = 0; i < count; i++)
        {
-               free(ht->remove(ht, &i));
+               r = GET_VALUE(i);
+               free(ht->remove(ht, &r));
        }
+       ck_assert_int_eq(0, ht->get_count(ht));
        for (i = 0; i < count; i++)
        {
                val = malloc_thing(int);
-               *val = i;
+               *val = GET_VALUE(i);
                ht->put(ht, val, val);
        }
        for (i = 0; i < count/2; i++)
        {
                free(ht->remove(ht, &i));
        }
+       ck_assert_int_eq(count/2, ht->get_count(ht));
        for (i = 0; i < count; i++)
        {
                val = malloc_thing(int);
-               *val = i;
+               *val = GET_VALUE(i);
                free(ht->put(ht, val, val));
        }
        srandom(666);
@@ -418,6 +507,7 @@ START_TEST(test_many_items)
        {
                free(ht->remove(ht, &i));
        }
+       ck_assert_int_eq(0, ht->get_count(ht));
        for (i = 0; i < 2*count; i++)
        {
                val = malloc_thing(int);
@@ -503,12 +593,15 @@ Suite *hashtable_suite_create()
        tc = tcase_create("get_match");
        tcase_add_test(tc, test_get_match);
        tcase_add_test(tc, test_get_match_remove);
+       tcase_add_test(tc, test_get_match_sorted);
        suite_add_tcase(s, tc);
 
        tc = tcase_create("remove");
        tcase_add_checked_fixture(tc, setup_ht, teardown_ht);
        tcase_add_test(tc, test_remove);
        tcase_add_test(tc, test_remove_one_bucket);
+       tcase_add_test(tc, test_remove_sorted);
+       tcase_add_test(tc, test_remove_sorted_one_bucket);
        suite_add_tcase(s, tc);
 
        tc = tcase_create("enumerator");
@@ -525,7 +618,16 @@ Suite *hashtable_suite_create()
        tc = tcase_create("many items");
        tcase_add_checked_fixture(tc, setup_ht_many, teardown_ht_many);
        tcase_set_timeout(tc, 10);
-       tcase_add_test(tc, test_many_items);
+       tcase_add_loop_test(tc, test_many_items, 0, 2);
+       tcase_add_test(tc, test_many_lookups_success);
+       tcase_add_test(tc, test_many_lookups_failure_larger);
+       tcase_add_test(tc, test_many_lookups_failure_smaller);
+       suite_add_tcase(s, tc);
+
+       tc = tcase_create("many items sorted");
+       tcase_add_checked_fixture(tc, setup_ht_many_cmp, teardown_ht_many);
+       tcase_set_timeout(tc, 10);
+       tcase_add_loop_test(tc, test_many_items, 0, 2);
        tcase_add_test(tc, test_many_lookups_success);
        tcase_add_test(tc, test_many_lookups_failure_larger);
        tcase_add_test(tc, test_many_lookups_failure_smaller);