]> git.ipfire.org Git - thirdparty/zstd.git/commitdiff
Move XXH64_update() into worker threads 1056/head
authorNick Terrell <terrelln@fb.com>
Thu, 1 Mar 2018 04:10:44 +0000 (20:10 -0800)
committerNick Terrell <terrelln@fb.com>
Mon, 19 Mar 2018 18:08:27 +0000 (11:08 -0700)
* Computes the XXH hash in the worker threads.
* Workers get a sequence number and wait until ther number shows up. On
  error, ensures that its sequence is finished, so future threads don't
  get blocked.
* Sets up for ldm integration, which will go in the same spot.

lib/compress/zstdmt_compress.c

index 4b236900793ad07e88cac9eb7ca233bb5b709836..e5bad42100e9b2957c5dcefc8d16b3450afacd7e 100644 (file)
@@ -304,16 +304,81 @@ static void ZSTDMT_releaseCCtx(ZSTDMT_CCtxPool* pool, ZSTD_CCtx* cctx)
     ZSTD_pthread_mutex_unlock(&pool->poolMutex);
 }
 
-
-/* ------------------------------------------ */
-/* =====          Worker thread         ===== */
-/* ------------------------------------------ */
+/* ====   Serial State   ==== */
 
 typedef struct {
     void const* start;
     size_t size;
 } range_t;
 
+typedef struct {
+    ZSTD_pthread_mutex_t mutex;
+    ZSTD_pthread_cond_t cond;
+    ZSTD_CCtx_params params;
+    XXH64_state_t xxhState;
+    unsigned nextJobID;
+} serialState_t;
+
+static void ZSTDMT_serialState_reset(serialState_t* serialState, ZSTD_CCtx_params params)
+{
+    serialState->nextJobID = 0;
+    if (params.fParams.checksumFlag)
+        XXH64_reset(&serialState->xxhState, 0);
+    serialState->params = params;
+}
+
+static int ZSTDMT_serialState_init(serialState_t* serialState)
+{
+    int initError = 0;
+    initError |= ZSTD_pthread_mutex_init(&serialState->mutex, NULL);
+    initError |= ZSTD_pthread_cond_init(&serialState->cond, NULL);
+    return initError;
+}
+
+static void ZSTDMT_serialState_free(serialState_t* serialState)
+{
+    ZSTD_pthread_mutex_destroy(&serialState->mutex);
+    ZSTD_pthread_cond_destroy(&serialState->cond);
+}
+
+static void ZSTDMT_serialState_update(serialState_t* serialState, range_t src, unsigned jobID)
+{
+    /* Wait for our turn */
+    ZSTD_PTHREAD_MUTEX_LOCK(&serialState->mutex);
+    while (serialState->nextJobID < jobID) {
+        ZSTD_pthread_cond_wait(&serialState->cond, &serialState->mutex);
+    }
+    /* A future job may error and skip our job */
+    if (serialState->nextJobID == jobID) {
+        /* It is now our turn, do any processing necessary */
+        if (serialState->params.fParams.checksumFlag && src.size > 0)
+            XXH64_update(&serialState->xxhState, src.start, src.size);
+    }
+    /* Now it is the next jobs turn */
+    serialState->nextJobID++;
+    ZSTD_pthread_cond_broadcast(&serialState->cond);
+    ZSTD_pthread_mutex_unlock(&serialState->mutex);
+}
+
+static void ZSTDMT_serialState_ensureFinished(serialState_t* serialState,
+                                              unsigned jobID, size_t cSize)
+{
+    ZSTD_PTHREAD_MUTEX_LOCK(&serialState->mutex);
+    if (serialState->nextJobID <= jobID) {
+        assert(ZSTD_isError(cSize)); (void)cSize;
+        DEBUGLOG(5, "Skipping past job %u because of error", jobID);
+        serialState->nextJobID = jobID + 1;
+        ZSTD_pthread_cond_broadcast(&serialState->cond);
+    }
+    ZSTD_pthread_mutex_unlock(&serialState->mutex);
+
+}
+
+
+/* ------------------------------------------ */
+/* =====          Worker thread         ===== */
+/* ------------------------------------------ */
+
 static const range_t kNullRange = { NULL, 0 };
 
 typedef struct {
@@ -323,9 +388,11 @@ typedef struct {
     ZSTD_pthread_cond_t job_cond;        /* Thread-safe - used by mtctx and worker */
     ZSTDMT_CCtxPool* cctxPool;           /* Thread-safe - used by mtctx and (all) workers */
     ZSTDMT_bufferPool* bufPool;          /* Thread-safe - used by mtctx and (all) workers */
+    serialState_t* serial;               /* Thread-safe - used by mtctx and (all) workers */
     buffer_t dstBuff;                    /* set by worker (or mtctx), then read by worker & mtctx, then modified by mtctx => no barrier */
     range_t prefix;                      /* set by mtctx, then read by worker & mtctx => no barrier */
     range_t src;                         /* set by mtctx, then read by worker & mtctx => no barrier */
+    unsigned jobID;                      /* set by mtctx, then read by worker => no barrier */
     unsigned firstJob;                   /* set by mtctx, then read by worker => no barrier */
     unsigned lastJob;                    /* set by mtctx, then read by worker => no barrier */
     ZSTD_CCtx_params params;             /* set by mtctx, then read by worker => no barrier */
@@ -339,9 +406,13 @@ typedef struct {
 void ZSTDMT_compressionJob(void* jobDescription)
 {
     ZSTDMT_jobDescription* const job = (ZSTDMT_jobDescription*)jobDescription;
+    ZSTD_CCtx_params jobParams = job->params;   /* do not modify job->params ! copy it, modify the copy */
     ZSTD_CCtx* const cctx = ZSTDMT_getCCtx(job->cctxPool);
     buffer_t dstBuff = job->dstBuff;
 
+    /* Don't compute the checksum for chunks, but write it in the header */
+    if (job->jobID != 0) jobParams.fParams.checksumFlag = 0;
+
     /* ressources */
     if (cctx==NULL) {
         job->cSize = ERROR(memory_allocation);
@@ -358,12 +429,11 @@ void ZSTDMT_compressionJob(void* jobDescription)
 
     /* init */
     if (job->cdict) {
-        size_t const initError = ZSTD_compressBegin_advanced_internal(cctx, NULL, 0, ZSTD_dm_auto, job->cdict, job->params, job->fullFrameSize);
+        size_t const initError = ZSTD_compressBegin_advanced_internal(cctx, NULL, 0, ZSTD_dm_auto, job->cdict, jobParams, job->fullFrameSize);
         assert(job->firstJob);  /* only allowed for first job */
         if (ZSTD_isError(initError)) { job->cSize = initError; goto _endJob; }
     } else {  /* srcStart points at reloaded section */
         U64 const pledgedSrcSize = job->firstJob ? job->fullFrameSize : job->src.size;
-        ZSTD_CCtx_params jobParams = job->params;   /* do not modify job->params ! copy it, modify the copy */
         {   size_t const forceWindowError = ZSTD_CCtxParam_setParameter(&jobParams, ZSTD_p_forceMaxWindow, !job->firstJob);
             if (ZSTD_isError(forceWindowError)) {
                 job->cSize = forceWindowError;
@@ -377,6 +447,10 @@ void ZSTDMT_compressionJob(void* jobDescription)
                 job->cSize = initError;
                 goto _endJob;
     }   }   }
+
+    /* Perform serial step as early as possible */
+    ZSTDMT_serialState_update(job->serial, job->src, job->jobID);
+
     if (!job->firstJob) {  /* flush and overwrite frame header when it's not first job */
         size_t const hSize = ZSTD_compressContinue(cctx, dstBuff.start, dstBuff.capacity, job->src.start, 0);
         if (ZSTD_isError(hSize)) { job->cSize = hSize; /* save error code */ goto _endJob; }
@@ -425,6 +499,7 @@ void ZSTDMT_compressionJob(void* jobDescription)
     }   }
 
 _endJob:
+    ZSTDMT_serialState_ensureFinished(job->serial, job->jobID, job->cSize);
     if (job->prefix.size > 0)
         DEBUGLOG(5, "Finished with prefix: %zx", (size_t)job->prefix.start);
     DEBUGLOG(5, "Finished with source: %zx", (size_t)job->src.start);
@@ -475,7 +550,7 @@ struct ZSTDMT_CCtx_s {
     roundBuff_t roundBuff;
     inBuff_t inBuff;
     int jobReady;        /* 1 => one job is already prepared, but pool has shortage of workers. Don't create another one. */
-    XXH64_state_t xxhState;
+    serialState_t serial;
     unsigned singleBlockingThread;
     unsigned jobIDMask;
     unsigned doneJobID;
@@ -540,6 +615,7 @@ ZSTDMT_CCtx* ZSTDMT_createCCtx_advanced(unsigned nbWorkers, ZSTD_customMem cMem)
 {
     ZSTDMT_CCtx* mtctx;
     U32 nbJobs = nbWorkers + 2;
+    int initError;
     DEBUGLOG(3, "ZSTDMT_createCCtx_advanced (nbWorkers = %u)", nbWorkers);
 
     if (nbWorkers < 1) return NULL;
@@ -559,8 +635,9 @@ ZSTDMT_CCtx* ZSTDMT_createCCtx_advanced(unsigned nbWorkers, ZSTD_customMem cMem)
     mtctx->jobIDMask = nbJobs - 1;
     mtctx->bufPool = ZSTDMT_createBufferPool(nbWorkers, cMem);
     mtctx->cctxPool = ZSTDMT_createCCtxPool(nbWorkers, cMem);
+    initError = ZSTDMT_serialState_init(&mtctx->serial);
     mtctx->roundBuff = kNullRoundBuff;
-    if (!mtctx->factory | !mtctx->jobs | !mtctx->bufPool | !mtctx->cctxPool) {
+    if (!mtctx->factory | !mtctx->jobs | !mtctx->bufPool | !mtctx->cctxPool | initError) {
         ZSTDMT_freeCCtx(mtctx);
         return NULL;
     }
@@ -615,6 +692,7 @@ size_t ZSTDMT_freeCCtx(ZSTDMT_CCtx* mtctx)
     ZSTDMT_freeJobsTable(mtctx->jobs, mtctx->jobIDMask+1, mtctx->cMem);
     ZSTDMT_freeBufferPool(mtctx->bufPool);
     ZSTDMT_freeCCtxPool(mtctx->cctxPool);
+    ZSTDMT_serialState_free(&mtctx->serial);
     ZSTD_freeCDict(mtctx->cdictLocal);
     if (mtctx->roundBuff.buffer)
         ZSTD_free(mtctx->roundBuff.buffer, mtctx->cMem);
@@ -779,7 +857,6 @@ static size_t ZSTDMT_compress_advanced_internal(
     size_t remainingSrcSize = srcSize;
     unsigned const compressWithinDst = (dstCapacity >= ZSTD_compressBound(srcSize)) ? nbJobs : (unsigned)(dstCapacity / ZSTD_compressBound(avgJobSize));  /* presumes avgJobSize >= 256 KB, which should be the case */
     size_t frameStartPos = 0, dstBufferPos = 0;
-    XXH64_state_t xxh64;
     assert(jobParams.nbWorkers == 0);
     assert(mtctx->cctxPool->totalCCtx == params.nbWorkers);
 
@@ -795,7 +872,7 @@ static size_t ZSTDMT_compress_advanced_internal(
 
     assert(avgJobSize >= 256 KB);  /* condition for ZSTD_compressBound(A) + ZSTD_compressBound(B) <= ZSTD_compressBound(A+B), required to compress directly into Dst (no additional buffer) */
     ZSTDMT_setBufferSize(mtctx->bufPool, ZSTD_compressBound(avgJobSize) );
-    XXH64_reset(&xxh64, 0);
+    ZSTDMT_serialState_reset(&mtctx->serial, params);
 
     if (nbJobs > mtctx->jobIDMask+1) {  /* enlarge job table */
         U32 jobsTableSize = nbJobs;
@@ -825,17 +902,14 @@ static size_t ZSTDMT_compress_advanced_internal(
             mtctx->jobs[u].fullFrameSize = srcSize;
             mtctx->jobs[u].params = jobParams;
             /* do not calculate checksum within sections, but write it in header for first section */
-            if (u!=0) mtctx->jobs[u].params.fParams.checksumFlag = 0;
             mtctx->jobs[u].dstBuff = dstBuffer;
             mtctx->jobs[u].cctxPool = mtctx->cctxPool;
             mtctx->jobs[u].bufPool = mtctx->bufPool;
+            mtctx->jobs[u].serial = &mtctx->serial;
+            mtctx->jobs[u].jobID = u;
             mtctx->jobs[u].firstJob = (u==0);
             mtctx->jobs[u].lastJob = (u==nbJobs-1);
 
-            if (params.fParams.checksumFlag) {
-                XXH64_update(&xxh64, srcStart + frameStartPos, jobSize);
-            }
-
             DEBUGLOG(5, "ZSTDMT_compress_advanced_internal: posting job %u  (%u bytes)", u, (U32)jobSize);
             DEBUG_PRINTHEX(6, mtctx->jobs[u].prefix.start, 12);
             POOL_add(mtctx->factory, ZSTDMT_compressionJob, &mtctx->jobs[u]);
@@ -876,7 +950,7 @@ static size_t ZSTDMT_compress_advanced_internal(
 
         DEBUGLOG(4, "checksumFlag : %u ", params.fParams.checksumFlag);
         if (params.fParams.checksumFlag) {
-            U32 const checksum = (U32)XXH64_digest(&xxh64);
+            U32 const checksum = (U32)XXH64_digest(&mtctx->serial.xxhState);
             if (dstPos + 4 > dstCapacity) {
                 error = ERROR(dstSize_tooSmall);
             } else {
@@ -1016,7 +1090,7 @@ size_t ZSTDMT_initCStream_internal(
     mtctx->allJobsCompleted = 0;
     mtctx->consumed = 0;
     mtctx->produced = 0;
-    if (params.fParams.checksumFlag) XXH64_reset(&mtctx->xxhState, 0);
+    ZSTDMT_serialState_reset(&mtctx->serial, params);
     return 0;
 }
 
@@ -1113,21 +1187,18 @@ static size_t ZSTDMT_createCompressionJob(ZSTDMT_CCtx* mtctx, size_t srcSize, ZS
         mtctx->jobs[jobID].consumed = 0;
         mtctx->jobs[jobID].cSize = 0;
         mtctx->jobs[jobID].params = mtctx->params;
-        /* do not calculate checksum within sections, but write it in header for first section */
-        if (mtctx->nextJobID) mtctx->jobs[jobID].params.fParams.checksumFlag = 0;
         mtctx->jobs[jobID].cdict = mtctx->nextJobID==0 ? mtctx->cdict : NULL;
         mtctx->jobs[jobID].fullFrameSize = mtctx->frameContentSize;
         mtctx->jobs[jobID].dstBuff = g_nullBuffer;
         mtctx->jobs[jobID].cctxPool = mtctx->cctxPool;
         mtctx->jobs[jobID].bufPool = mtctx->bufPool;
+        mtctx->jobs[jobID].serial = &mtctx->serial;
+        mtctx->jobs[jobID].jobID = mtctx->nextJobID;
         mtctx->jobs[jobID].firstJob = (mtctx->nextJobID==0);
         mtctx->jobs[jobID].lastJob = endFrame;
         mtctx->jobs[jobID].frameChecksumNeeded = endFrame && (mtctx->nextJobID>0) && mtctx->params.fParams.checksumFlag;
         mtctx->jobs[jobID].dstFlushed = 0;
 
-        if (mtctx->params.fParams.checksumFlag && srcSize > 0)
-            XXH64_update(&mtctx->xxhState, src, srcSize);
-
         /* Update the round buffer pos and clear the input buffer to be reset */
         mtctx->roundBuff.pos += srcSize;
         mtctx->inBuff.buffer = g_nullBuffer;
@@ -1214,7 +1285,7 @@ static size_t ZSTDMT_flushProduced(ZSTDMT_CCtx* mtctx, ZSTD_outBuffer* output, u
         assert(srcConsumed <= srcSize);
         if ( (srcConsumed == srcSize)   /* job completed -> worker no longer active */
           && mtctx->jobs[wJobID].frameChecksumNeeded ) {
-            U32 const checksum = (U32)XXH64_digest(&mtctx->xxhState);
+            U32 const checksum = (U32)XXH64_digest(&mtctx->serial.xxhState);
             DEBUGLOG(4, "ZSTDMT_flushProduced: writing checksum : %08X \n", checksum);
             MEM_writeLE32((char*)mtctx->jobs[wJobID].dstBuff.start + mtctx->jobs[wJobID].cSize, checksum);
             cSize += 4;