]> git.ipfire.org Git - thirdparty/kernel/stable.git/commitdiff
test_rhashtable: remove semaphore usage
authorArnd Bergmann <arnd@arndb.de>
Sun, 16 Dec 2018 19:48:21 +0000 (20:48 +0100)
committerDavid S. Miller <davem@davemloft.net>
Tue, 18 Dec 2018 23:12:53 +0000 (15:12 -0800)
This is one of only two files that initialize a semaphore to a negative
value. We don't really need the two semaphores here at all, but can do
the same thing in more conventional and more effient way, by using a
single waitqueue and an atomic thread counter.

This gets us a little bit closer to eliminating classic semaphores from
the kernel. It also fixes a corner case where we fail to continue after
one of the threads fails to start up.

An alternative would be to use a split kthread_create()+wake_up_process()
and completely eliminate the separate synchronization.

Acked-by: Phil Sutter <phil@nwl.cc>
Signed-off-by: Arnd Bergmann <arnd@arndb.de>
Acked-by: Herbert Xu <herbert@gondor.apana.org.au>
Signed-off-by: David S. Miller <davem@davemloft.net>
lib/test_rhashtable.c

index 82ac39ce53105f2dc39d517467333b255ae218cb..6a8ac7626797854899e77afb4cff0505272dca97 100644 (file)
 #include <linux/module.h>
 #include <linux/rcupdate.h>
 #include <linux/rhashtable.h>
-#include <linux/semaphore.h>
 #include <linux/slab.h>
 #include <linux/sched.h>
 #include <linux/random.h>
 #include <linux/vmalloc.h>
+#include <linux/wait.h>
 
 #define MAX_ENTRIES    1000000
 #define TEST_INSERT_FAIL INT_MAX
@@ -112,8 +112,8 @@ static struct rhashtable_params test_rht_params_dup = {
        .automatic_shrinking = false,
 };
 
-static struct semaphore prestart_sem;
-static struct semaphore startup_sem = __SEMAPHORE_INITIALIZER(startup_sem, 0);
+static atomic_t startup_count;
+static DECLARE_WAIT_QUEUE_HEAD(startup_wait);
 
 static int insert_retry(struct rhashtable *ht, struct test_obj *obj,
                         const struct rhashtable_params params)
@@ -634,9 +634,12 @@ static int threadfunc(void *data)
        int i, step, err = 0, insert_retries = 0;
        struct thread_data *tdata = data;
 
-       up(&prestart_sem);
-       if (down_interruptible(&startup_sem))
-               pr_err("  thread[%d]: down_interruptible failed\n", tdata->id);
+       if (atomic_dec_and_test(&startup_count))
+               wake_up(&startup_wait);
+       if (wait_event_interruptible(startup_wait, atomic_read(&startup_count) == -1)) {
+               pr_err("  thread[%d]: interrupted\n", tdata->id);
+               goto out;
+       }
 
        for (i = 0; i < tdata->entries; i++) {
                tdata->objs[i].value.id = i;
@@ -755,7 +758,7 @@ static int __init test_rht_init(void)
 
        pr_info("Testing concurrent rhashtable access from %d threads\n",
                tcount);
-       sema_init(&prestart_sem, 1 - tcount);
+       atomic_set(&startup_count, tcount);
        tdata = vzalloc(array_size(tcount, sizeof(struct thread_data)));
        if (!tdata)
                return -ENOMEM;
@@ -781,15 +784,18 @@ static int __init test_rht_init(void)
                tdata[i].objs = objs + i * entries;
                tdata[i].task = kthread_run(threadfunc, &tdata[i],
                                            "rhashtable_thrad[%d]", i);
-               if (IS_ERR(tdata[i].task))
+               if (IS_ERR(tdata[i].task)) {
                        pr_err(" kthread_run failed for thread %d\n", i);
-               else
+                       atomic_dec(&startup_count);
+               } else {
                        started_threads++;
+               }
        }
-       if (down_interruptible(&prestart_sem))
-               pr_err("  down interruptible failed\n");
-       for (i = 0; i < tcount; i++)
-               up(&startup_sem);
+       if (wait_event_interruptible(startup_wait, atomic_read(&startup_count) == 0))
+               pr_err("  wait_event interruptible failed\n");
+       /* count is 0 now, set it to -1 and wake up all threads together */
+       atomic_dec(&startup_count);
+       wake_up_all(&startup_wait);
        for (i = 0; i < tcount; i++) {
                if (IS_ERR(tdata[i].task))
                        continue;