]> git.ipfire.org Git - thirdparty/zstd.git/commitdiff
Add threadPool unit tests to fuzzer.c (#2604)
authorsen <senhuang96@fb.com>
Fri, 7 May 2021 15:13:44 +0000 (11:13 -0400)
committerGitHub <noreply@github.com>
Fri, 7 May 2021 15:13:44 +0000 (11:13 -0400)
tests/fuzzer.c

index 46147a24bee420c9445c699f7a87a86d3f61ea31..b4b2b6f75f1b40edec063adaa8e0609f053e7923 100644 (file)
@@ -42,6 +42,7 @@
 #include "timefn.h"       /* SEC_TO_MICRO, UTIL_time_t, UTIL_TIME_INITIALIZER, UTIL_clockSpanMicro, UTIL_getTime */
 /* must be included after util.h, due to ERROR macro redefinition issue on Visual Studio */
 #include "zstd_internal.h"  /* ZSTD_WORKSPACETOOLARGE_MAXDURATION, ZSTD_WORKSPACETOOLARGE_FACTOR, KB, MB */
+#include "threading.h"    /* ZSTD_pthread_create, ZSTD_pthread_join */
 
 
 /*-************************************
@@ -335,6 +336,126 @@ static void FUZ_decodeSequences(BYTE* dst, ZSTD_Sequence* seqs, size_t seqsSize,
     }
 }
 
+#ifdef ZSTD_MULTITHREAD
+typedef struct {
+    ZSTD_CCtx* cctx;
+    ZSTD_threadPool* pool;
+    void* const CNBuffer;
+    size_t CNBuffSize;
+    void* const compressedBuffer;
+    size_t compressedBufferSize;
+    void* const decodedBuffer;
+    int err;
+} threadPoolTests_compressionJob_payload;
+
+static void* threadPoolTests_compressionJob(void* payload) {
+    threadPoolTests_compressionJob_payload* args = (threadPoolTests_compressionJob_payload*)payload;
+    size_t cSize;
+    if (ZSTD_isError(ZSTD_CCtx_refThreadPool(args->cctx, args->pool))) args->err = 1;
+    cSize = ZSTD_compress2(args->cctx, args->compressedBuffer, args->compressedBufferSize, args->CNBuffer, args->CNBuffSize);
+    if (ZSTD_isError(cSize)) args->err = 1;
+    if (ZSTD_isError(ZSTD_decompress(args->decodedBuffer, args->CNBuffSize, args->compressedBuffer, cSize))) args->err = 1;
+    return payload;
+}
+
+static int threadPoolTests(void) {
+    int testResult = 0;
+    size_t err;
+
+    size_t const CNBuffSize = 5 MB;
+    void* const CNBuffer = malloc(CNBuffSize);
+    size_t const compressedBufferSize = ZSTD_compressBound(CNBuffSize);
+    void* const compressedBuffer = malloc(compressedBufferSize);
+    void* const decodedBuffer = malloc(CNBuffSize);
+
+    size_t const kPoolNumThreads = 8;
+
+    RDG_genBuffer(CNBuffer, CNBuffSize, 0.5, 0.5, 0);
+
+    DISPLAYLEVEL(3, "thread pool test : threadPool re-use roundtrips: ");
+    {
+        ZSTD_CCtx* cctx = ZSTD_createCCtx();
+        ZSTD_threadPool* pool = ZSTD_createThreadPool(kPoolNumThreads);
+
+        size_t nbThreads = 1;
+        for (; nbThreads <= kPoolNumThreads; ++nbThreads) {
+            ZSTD_CCtx_reset(cctx, ZSTD_reset_session_and_parameters);
+            ZSTD_CCtx_setParameter(cctx, ZSTD_c_nbWorkers, (int)nbThreads);
+            err = ZSTD_CCtx_refThreadPool(cctx, pool);
+            if (ZSTD_isError(err)) {
+                DISPLAYLEVEL(3, "refThreadPool error!\n");
+                ZSTD_freeCCtx(cctx);
+                goto _output_error;
+            }
+            err = ZSTD_compress2(cctx, compressedBuffer, compressedBufferSize, CNBuffer, CNBuffSize);
+            if (ZSTD_isError(err)) {
+                DISPLAYLEVEL(3, "Compression error!\n");
+                ZSTD_freeCCtx(cctx);
+                goto _output_error;
+            }
+            err = ZSTD_decompress(decodedBuffer, CNBuffSize, compressedBuffer, err);
+            if (ZSTD_isError(err)) {
+                DISPLAYLEVEL(3, "Decompression error!\n");
+                ZSTD_freeCCtx(cctx);
+                goto _output_error;
+            }
+        }
+
+        ZSTD_freeCCtx(cctx);
+        ZSTD_freeThreadPool(pool);
+    }
+    DISPLAYLEVEL(3, "OK \n");
+
+    DISPLAYLEVEL(3, "thread pool test : threadPool simultaneous usage: ");
+    {
+        void* const decodedBuffer2 = malloc(CNBuffSize);
+        void* const compressedBuffer2 = malloc(compressedBufferSize);
+        ZSTD_threadPool* pool = ZSTD_createThreadPool(kPoolNumThreads);
+        ZSTD_CCtx* cctx1 = ZSTD_createCCtx();
+        ZSTD_CCtx* cctx2 = ZSTD_createCCtx();
+
+        ZSTD_pthread_t t1;
+        ZSTD_pthread_t t2;
+        threadPoolTests_compressionJob_payload p1 = {cctx1, pool, CNBuffer, CNBuffSize,
+                                                     compressedBuffer, compressedBufferSize, decodedBuffer, 0 /* err */};
+        threadPoolTests_compressionJob_payload p2 = {cctx2, pool, CNBuffer, CNBuffSize,
+                                                     compressedBuffer2, compressedBufferSize, decodedBuffer2, 0 /* err */};
+
+        ZSTD_CCtx_setParameter(cctx1, ZSTD_c_nbWorkers, 2);
+        ZSTD_CCtx_setParameter(cctx2, ZSTD_c_nbWorkers, 2);
+        ZSTD_CCtx_refThreadPool(cctx1, pool);
+        ZSTD_CCtx_refThreadPool(cctx2, pool);
+
+        ZSTD_pthread_create(&t1, NULL, threadPoolTests_compressionJob, &p1);
+        ZSTD_pthread_create(&t2, NULL, threadPoolTests_compressionJob, &p2);
+        ZSTD_pthread_join(t1, NULL);
+        ZSTD_pthread_join(t2, NULL);
+
+        assert(!memcmp(decodedBuffer, decodedBuffer2, CNBuffSize));
+        free(decodedBuffer2);
+        free(compressedBuffer2);
+
+        ZSTD_freeThreadPool(pool);
+        ZSTD_freeCCtx(cctx1);
+        ZSTD_freeCCtx(cctx2);
+
+        if (p1.err || p2.err) goto _output_error;
+    }
+    DISPLAYLEVEL(3, "OK \n");
+
+_end:
+    free(CNBuffer);
+    free(compressedBuffer);
+    free(decodedBuffer);
+    return testResult;
+
+_output_error:
+    testResult = 1;
+    DISPLAY("Error detected in Unit tests ! \n");
+    goto _end;
+}
+#endif /* ZSTD_MULTITHREAD */
+
 /*=============================================
 *   Unit tests
 =============================================*/
@@ -3292,7 +3413,16 @@ static int basicUnitTests(U32 const seed, double compressibility)
     }
     DISPLAYLEVEL(3, "OK \n");
 
-#endif
+    DISPLAYLEVEL(3, "test%3i : thread pool API tests : \n", testNb++)
+    {
+        int const threadPoolTestResult = threadPoolTests();
+        if (threadPoolTestResult) {
+            goto _output_error;
+        }
+    }
+    DISPLAYLEVEL(3, "thread pool tests OK \n");
+
+#endif /* ZSTD_MULTITHREAD */
 
 _end:
     free(CNBuffer);