]> git.ipfire.org Git - thirdparty/zstd.git/commitdiff
protect buffer pool with a mutex
authorYann Collet <cyan@fb.com>
Wed, 28 Dec 2016 14:31:19 +0000 (15:31 +0100)
committerYann Collet <cyan@fb.com>
Wed, 28 Dec 2016 14:31:19 +0000 (15:31 +0100)
lib/compress/zstdmt_compress.c

index 13cc19488abd9a4fa48a870ef58ce34914e80293..c698dce0eee6ec07cd90c78997627a52feb5b26a 100644 (file)
@@ -1,5 +1,5 @@
 #include <stdlib.h>   /* malloc */
-#include <pthread.h>
+#include <pthread.h>  /* posix only, to be replaced by a more portable version */
 #include "zstd_internal.h"   /* MIN, ERROR */
 #include "zstdmt_compress.h"
 
@@ -39,7 +39,7 @@ static ZSTDMT_dstBufferManager ZSTDMT_createDstBufferManager(void* dst, size_t d
     dbm.frameIDToWrite = 0;
     pthread_mutex_init(&dbm.frameTable_mutex, NULL);
     pthread_mutex_init(&dbm.allFramesWritten_mutex, NULL);
-    pthread_mutex_lock(&dbm.allFramesWritten_mutex);
+    pthread_mutex_lock(&dbm.allFramesWritten_mutex);  /* maybe could be merged into init ? */
     dbm.nbStackedFrames = 0;
     return dbm;
 }
@@ -92,7 +92,7 @@ static size_t ZSTDMT_tryWriteFrame(ZSTDMT_dstBufferManager* dstBufferManager,
     pthread_mutex_lock(&dstBufferManager->frameTable_mutex);
     if (frameID != dstBufferManager->frameIDToWrite) {
         DEBUGLOG(4, "writing frameID %u : not possible, waiting for %u  ", frameID, dstBufferManager->frameIDToWrite);
-        frameToWrite_t frame = { src, srcSize, frameID, isLastFrame };
+        frameToWrite_t const frame = { src, srcSize, frameID, isLastFrame };
         ZSTDMT_stackFrameToWrite(dstBufferManager, frame);
         pthread_mutex_unlock(&dstBufferManager->frameTable_mutex);
         return 0;
@@ -121,9 +121,11 @@ static size_t ZSTDMT_tryWriteFrame(ZSTDMT_dstBufferManager* dstBufferManager,
         for (u=0; u<dstBufferManager->nbStackedFrames; u++) {
             if (dstBufferManager->stackedFrame[u].frameID == frameID) {
                 pthread_mutex_unlock(&dstBufferManager->frameTable_mutex);
+                DEBUGLOG(4, "catch up frame %u ", frameID);
                 { size_t const writeError = ZSTDMT_writeFrame(dstBufferManager, u);
                   if (ZSTD_isError(writeError)) return writeError; }
                 lastFrameWritten = dstBufferManager->stackedFrame[u].isLastFrame;
+                dstBufferManager->frameIDToWrite = frameID+1;
                 /* remove frame from stack */
                 pthread_mutex_lock(&dstBufferManager->frameTable_mutex);
                 dstBufferManager->stackedFrame[u] = dstBufferManager->stackedFrame[dstBufferManager->nbStackedFrames-1];
@@ -183,20 +185,24 @@ static ZSTDMT_jobDescription ZSTDMT_getjob(ZSTDMT_jobAgency* jobAgency)
 
 #define ZSTDMT_NBBUFFERSPOOLED_MAX ZSTDMT_NBTHREADS_MAX
 typedef struct ZSTDMT_bufferPool_s {
+    pthread_mutex_t bufferPool_mutex;
     buffer_t bTable[ZSTDMT_NBBUFFERSPOOLED_MAX];
     unsigned nbBuffers;
 } ZSTDMT_bufferPool;
 
 static buffer_t ZSTDMT_getBuffer(ZSTDMT_bufferPool* pool, size_t bSize)
 {
+    pthread_mutex_lock(&pool->bufferPool_mutex);
     if (pool->nbBuffers) {   /* try to use an existing buffer */
         pool->nbBuffers--;
         buffer_t const buf = pool->bTable[pool->nbBuffers];
+        pthread_mutex_unlock(&pool->bufferPool_mutex);
         size_t const availBufferSize = buf.bufferSize;
         if ((availBufferSize >= bSize) & (availBufferSize <= 10*bSize))   /* large enough, but not too much */
             return buf;
         free(buf.start);   /* size conditions not respected : create a new buffer */
     }
+    pthread_mutex_unlock(&pool->bufferPool_mutex);
     /* create new buffer */
     buffer_t buf;
     buf.bufferSize = bSize;
@@ -207,11 +213,14 @@ static buffer_t ZSTDMT_getBuffer(ZSTDMT_bufferPool* pool, size_t bSize)
 /* effectively store buffer for later re-use, up to pool capacity */
 static void ZSTDMT_releaseBuffer(ZSTDMT_bufferPool* pool, buffer_t buf)
 {
+    pthread_mutex_lock(&pool->bufferPool_mutex);
     if (pool->nbBuffers >= ZSTDMT_NBBUFFERSPOOLED_MAX) {
+        pthread_mutex_unlock(&pool->bufferPool_mutex);
         free(buf.start);
         return;
     }
     pool->bTable[pool->nbBuffers++] = buf;   /* store for later re-use */
+    pthread_mutex_unlock(&pool->bufferPool_mutex);
 }
 
 
@@ -253,9 +262,12 @@ ZSTDMT_CCtx *ZSTDMT_createCCtx(unsigned nbThreads)
     if ((nbThreads < 1) | (nbThreads > ZSTDMT_NBTHREADS_MAX)) return NULL;
     ZSTDMT_CCtx* const cctx = (ZSTDMT_CCtx*) calloc(1, sizeof(ZSTDMT_CCtx));
     if (!cctx) return NULL;
+    /* init jobAgency */
     pthread_mutex_init(&cctx->jobAgency.jobAnnounce_mutex, NULL);   /* check return value ? */
     pthread_mutex_init(&cctx->jobAgency.jobApply_mutex, NULL);
     pthread_mutex_lock(&cctx->jobAgency.jobAnnounce_mutex);   /* no job at beginning */
+    /* init bufferPool */
+    pthread_mutex_init(&cctx->bufferPool.bufferPool_mutex, NULL);
     /* start all workers */
     cctx->nbThreads = nbThreads;
     DEBUGLOG(2, "nbThreads : %u \n", nbThreads);