]> git.ipfire.org Git - thirdparty/zstd.git/commitdiff
Fix race condition in the Windows thread / pthread translation layer
authorYonatan Komornik <yoniko@gmail.com>
Sat, 17 Dec 2022 02:24:02 +0000 (18:24 -0800)
committerYonatan Komornik <yoniko@gmail.com>
Sat, 17 Dec 2022 21:38:02 +0000 (13:38 -0800)
When spawning a Windows thread we have small worker wrapper function that translates
between the interfaces of Windows and POSIX threads.
This wrapper is given a pointer that might get stale before the worker starts running,
resulting in UB and crashes.
This commit adds synchronization so that we know the wrapper has finished reading the data
it needs before we allow the main thread to resume execution.

lib/common/threading.c
lib/common/threading.h

index 6bbb1493734ecf1cc6d9951684be9d61990c601b..825826500a67116a51f92400bf567963fe9734df 100644 (file)
@@ -34,35 +34,72 @@ int g_ZSTD_threading_useless_symbol;
 
 /* ===  Implementation  === */
 
+typedef struct {
+    void* (*start_routine)(void*);
+    void* arg;
+    int initialized;
+    ZSTD_pthread_cond_t initialized_cond;
+    ZSTD_pthread_mutex_t initialized_mutex;
+} ZSTD_thread_params_t;
+
 static unsigned __stdcall worker(void *arg)
 {
-    ZSTD_pthread_t* const thread = (ZSTD_pthread_t*) arg;
-    thread->start_routine(thread->arg);
+    ZSTD_thread_params_t* const thread_param = (ZSTD_thread_params_t*)arg;
+    void* (*start_routine)(void*) = thread_param->start_routine;
+    void* thread_arg = thread_param->arg;
+
+    /* Signal main thread that we are running and do not depend on its memory anymore */
+    ZSTD_pthread_mutex_lock(&thread_param->initialized_mutex);
+    thread_param->initialized = 1;
+    ZSTD_pthread_mutex_unlock(&thread_param->initialized_mutex);
+    ZSTD_pthread_cond_signal(&thread_param->initialized_cond);
+
+    start_routine(thread_arg);
+
     return 0;
 }
 
 int ZSTD_pthread_create(ZSTD_pthread_t* thread, const void* unused,
             void* (*start_routine) (void*), void* arg)
 {
+    ZSTD_thread_params_t thread_param;
+    int error = 0;
     (void)unused;
-    thread->arg = arg;
-    thread->start_routine = start_routine;
-    thread->handle = (HANDLE) _beginthreadex(NULL, 0, worker, thread, 0, NULL);
-
-    if (!thread->handle)
+    thread_param.start_routine = start_routine;
+    thread_param.arg = arg;
+    thread_param.initialized = 0;
+
+    /* Setup thread initialization synchronization */
+    error |= ZSTD_pthread_cond_init(&thread_param.initialized_cond, NULL);
+    error |= ZSTD_pthread_mutex_init(&thread_param.initialized_mutex, NULL);
+    if(error)
+        return -1;
+    ZSTD_pthread_mutex_lock(&thread_param.initialized_mutex);
+
+    /* Spawn thread */
+    *thread = (HANDLE)_beginthreadex(NULL, 0, worker, &thread_param, 0, NULL);
+    if (!thread)
         return errno;
-    else
-        return 0;
+
+    /* Wait for thread to be initialized */
+    while(!thread_param.initialized) {
+        ZSTD_pthread_cond_wait(&thread_param.initialized_cond, &thread_param.initialized_mutex);
+    }
+    ZSTD_pthread_mutex_unlock(&thread_param.initialized_mutex);
+    ZSTD_pthread_mutex_destroy(&thread_param.initialized_mutex);
+    ZSTD_pthread_cond_destroy(&thread_param.initialized_cond);
+
+    return 0;
 }
 
 int ZSTD_pthread_join(ZSTD_pthread_t thread)
 {
     DWORD result;
 
-    if (!thread.handle) return 0;
+    if (!thread) return 0;
 
-    result = WaitForSingleObject(thread.handle, INFINITE);
-    CloseHandle(thread.handle);
+    result = WaitForSingleObject(thread, INFINITE);
+    CloseHandle(thread);
 
     switch (result) {
     case WAIT_OBJECT_0:
index 603d479c7fad14754239bd46dd1b06a5dcfa8745..fb5c1c8787343d6de1f1075d41e9601ca90b1f91 100644 (file)
@@ -61,11 +61,7 @@ extern "C" {
 #define ZSTD_pthread_cond_broadcast(a)  WakeAllConditionVariable((a))
 
 /* ZSTD_pthread_create() and ZSTD_pthread_join() */
-typedef struct {
-    HANDLE handle;
-    void* (*start_routine)(void*);
-    void* arg;
-} ZSTD_pthread_t;
+typedef HANDLE ZSTD_pthread_t;
 
 int ZSTD_pthread_create(ZSTD_pthread_t* thread, const void* unused,
                    void* (*start_routine) (void*), void* arg);