]> git.ipfire.org Git - thirdparty/zstd.git/commitdiff
seekable_format: Make parallel_compression use memory properly
authorDave Vasilevsky <dave@vasilevsky.ca>
Wed, 7 May 2025 03:26:32 +0000 (23:26 -0400)
committerYann Collet <Cyan4973@users.noreply.github.com>
Thu, 8 May 2025 05:01:49 +0000 (22:01 -0700)
Previously, parallel_compression would only handle each job's results
after ALL jobs were successfully queued. This caused all src/dst
buffers to remain in memory until then!

It also polled to check whether a job completed, which is racy without
any memory barrier.

Now, we flush results as a side effect of completing a job. Completed
frames are placed in an ordered linked-list, and any eligible frames
are flushed. This may be zero or multiple frames, depending on the
order in which jobs finish.

This design also makes it simple to support streaming input, so that
is now available. Just pass `-` as the filename, and stdin/stdout will
be used for I/O.

contrib/seekable_format/examples/parallel_compression.c

index 4e06fae324125e518264614fc2657a27d66f84be..d54704c11ec2b678c8fbd9669b162cc8f5b2ad2a 100644 (file)
@@ -23,6 +23,8 @@
 
 #include "xxhash.h"
 
+#define ZSTD_MULTITHREAD 1
+#include "threading.h"
 #include "pool.h"      // use zstd thread pool for demo
 
 #include "../zstd_seekable.h"
@@ -72,127 +74,165 @@ static size_t fclose_orDie(FILE* file)
     exit(6);
 }
 
-static void fseek_orDie(FILE* file, long int offset, int origin)
-{
-    if (!fseek(file, offset, origin)) {
-        if (!fflush(file)) return;
-    }
-    /* error */
-    perror("fseek");
-    exit(7);
-}
-
-static long int ftell_orDie(FILE* file)
-{
-    long int off = ftell(file);
-    if (off != -1) return off;
-    /* error */
-    perror("ftell");
-    exit(8);
-}
+struct state {
+    FILE* fout;
+    ZSTD_pthread_mutex_t mutex;
+    size_t nextID;
+    struct job* pending;
+    ZSTD_frameLog* frameLog;
+    const int compressionLevel;
+};
 
 struct job {
-    const void* src;
+    size_t id;
+    struct job* next;
+    struct state* state;
+
+    void* src;
     size_t srcSize;
     void* dst;
     size_t dstSize;
 
     unsigned checksum;
-
-    int compressionLevel;
-    int done;
 };
 
+static void addPending_inmutex(struct state* state, struct job* job)
+{
+    struct job** p = &state->pending;
+    while (*p && (*p)->id < job->id)
+        p = &(*p)->next;
+    job->next = *p;
+    *p = job;
+}
+
+static void flushFrame(struct state* state, struct job* job)
+{
+    fwrite_orDie(job->dst, job->dstSize, state->fout);
+    free(job->dst);
+
+    size_t ret = ZSTD_seekable_logFrame(state->frameLog, job->dstSize, job->srcSize, job->checksum);
+    if (ZSTD_isError(ret)) {
+        fprintf(stderr, "ZSTD_seekable_logFrame() error : %s \n", ZSTD_getErrorName(ret));
+        exit(12);
+    }
+}
+
+static void flushPending_inmutex(struct state* state)
+{
+    while (state->pending && state->pending->id == state->nextID) {
+        struct job* p = state->pending;
+        state->pending = p->next;
+        flushFrame(state, p);
+        free(p);
+        state->nextID++;
+    }
+}
+
+static void finishFrame(struct job* job)
+{
+    struct state *state = job->state;
+    ZSTD_pthread_mutex_lock(&state->mutex);
+    addPending_inmutex(state, job);
+    flushPending_inmutex(state);
+    ZSTD_pthread_mutex_unlock(&state->mutex);
+}
+
 static void compressFrame(void* opaque)
 {
     struct job* job = opaque;
 
     job->checksum = XXH64(job->src, job->srcSize, 0);
 
-    size_t ret = ZSTD_compress(job->dst, job->dstSize, job->src, job->srcSize, job->compressionLevel);
+    size_t ret = ZSTD_compress(job->dst, job->dstSize, job->src, job->srcSize, job->state->compressionLevel);
     if (ZSTD_isError(ret)) {
         fprintf(stderr, "ZSTD_compress() error : %s \n", ZSTD_getErrorName(ret));
         exit(20);
     }
-
     job->dstSize = ret;
-    job->done = 1;
+
+    // No longer need
+    free(job->src);
+    job->src = NULL;
+
+    finishFrame(job);
 }
 
-static void compressFile_orDie(const char* fname, const char* outName, int cLevel, unsigned frameSize, int nbThreads)
+static const char* createOutFilename_orDie(const char* filename)
 {
+    size_t const inL = strlen(filename);
+    size_t const outL = inL + 5;
+    void* outSpace = malloc_orDie(outL);
+    memset(outSpace, 0, outL);
+    strcat(outSpace, filename);
+    strcat(outSpace, ".zst");
+    return (const char*)outSpace;
+}
+
+static void openInOut_orDie(const char* fname, FILE** fin, FILE** fout) {
+    if (strcmp(fname, "-") == 0) {
+        *fin = stdin;
+        *fout = stdout;
+    } else {
+        *fin = fopen_orDie(fname, "rb");
+        const char* outName = createOutFilename_orDie(fname);
+        *fout = fopen_orDie(outName, "wb");
+    }
+}
+
+static void compressFile_orDie(const char* fname, int cLevel, unsigned frameSize, int nbThreads)
+{
+    struct state state = {
+        .nextID = 0,
+        .pending = NULL,
+        .compressionLevel = cLevel,
+    };
+    ZSTD_pthread_mutex_init(&state.mutex, NULL);
+    state.frameLog = ZSTD_seekable_createFrameLog(1);
+    if (state.frameLog == NULL) { fprintf(stderr, "ZSTD_seekable_createFrameLog() failed \n"); exit(11); }
+
     POOL_ctx* pool = POOL_create(nbThreads, nbThreads);
     if (pool == NULL) { fprintf(stderr, "POOL_create() error \n"); exit(9); }
 
-    FILE* const fin  = fopen_orDie(fname, "rb");
-    FILE* const fout = fopen_orDie(outName, "wb");
+    FILE* fin;
+    openInOut_orDie(fname, &fin, &state.fout);
 
     if (ZSTD_compressBound(frameSize) > 0xFFFFFFFFU) { fprintf(stderr, "Frame size too large \n"); exit(10); }
     unsigned dstSize = ZSTD_compressBound(frameSize);
 
-
-    fseek_orDie(fin, 0, SEEK_END);
-    long int length = ftell_orDie(fin);
-    fseek_orDie(fin, 0, SEEK_SET);
-
-    size_t numFrames = (length + frameSize - 1) / frameSize;
-
-    struct job* jobs = malloc_orDie(sizeof(struct job) * numFrames);
-
-    size_t i;
-    for(i = 0; i < numFrames; i++) {
-        void* in = malloc_orDie(frameSize);
-        void* out = malloc_orDie(dstSize);
-
-        size_t inSize = fread_orDie(in, frameSize, fin);
-
-        jobs[i].src = in;
-        jobs[i].srcSize = inSize;
-        jobs[i].dst = out;
-        jobs[i].dstSize = dstSize;
-        jobs[i].compressionLevel = cLevel;
-        jobs[i].done = 0;
-        POOL_add(pool, compressFrame, &jobs[i]);
+    for (size_t id = 0; 1; id++) {
+        struct job* job = malloc_orDie(sizeof(struct job));
+        job->id = id;
+        job->next = NULL;
+        job->state = &state;
+        job->src = malloc_orDie(frameSize);
+        job->dst = malloc_orDie(dstSize);
+        job->srcSize = fread_orDie(job->src, frameSize, fin);
+        job->dstSize = dstSize; 
+        POOL_add(pool, compressFrame, job);
+        if (feof(fin))
+            break;
     }
 
-    ZSTD_frameLog* fl = ZSTD_seekable_createFrameLog(1);
-    if (fl == NULL) { fprintf(stderr, "ZSTD_seekable_createFrameLog() failed \n"); exit(11); }
-    for (i = 0; i < numFrames; i++) {
-        while (!jobs[i].done) SLEEP(5); /* wake up every 5 milliseconds to check */
-        fwrite_orDie(jobs[i].dst, jobs[i].dstSize, fout);
-        free((void*)jobs[i].src);
-        free(jobs[i].dst);
-
-        size_t ret = ZSTD_seekable_logFrame(fl, jobs[i].dstSize, jobs[i].srcSize, jobs[i].checksum);
-        if (ZSTD_isError(ret)) { fprintf(stderr, "ZSTD_seekable_logFrame() error : %s \n", ZSTD_getErrorName(ret)); }
+    POOL_joinJobs(pool);
+    if (state.pending) {
+        fprintf(stderr, "Unexpected leftover output blocks!\n");
+        exit(13);
     }
 
     {   unsigned char seekTableBuff[1024];
         ZSTD_outBuffer out = {seekTableBuff, 1024, 0};
-        while (ZSTD_seekable_writeSeekTable(fl, &out) != 0) {
-            fwrite_orDie(seekTableBuff, out.pos, fout);
+        while (ZSTD_seekable_writeSeekTable(state.frameLog, &out) != 0) {
+            fwrite_orDie(seekTableBuff, out.pos, state.fout);
             out.pos = 0;
         }
-        fwrite_orDie(seekTableBuff, out.pos, fout);
+        fwrite_orDie(seekTableBuff, out.pos, state.fout);
     }
 
-    ZSTD_seekable_freeFrameLog(fl);
-    free(jobs);
-    fclose_orDie(fout);
+    ZSTD_seekable_freeFrameLog(state.frameLog);
+    fclose_orDie(state.fout);
     fclose_orDie(fin);
 }
 
-static const char* createOutFilename_orDie(const char* filename)
-{
-    size_t const inL = strlen(filename);
-    size_t const outL = inL + 5;
-    void* outSpace = malloc_orDie(outL);
-    memset(outSpace, 0, outL);
-    strcat(outSpace, filename);
-    strcat(outSpace, ".zst");
-    return (const char*)outSpace;
-}
-
 int main(int argc, const char** argv) {
     const char* const exeName = argv[0];
     if (argc!=4) {
@@ -206,8 +246,7 @@ int main(int argc, const char** argv) {
         unsigned const frameSize = (unsigned)atoi(argv[2]);
         int const nbThreads = atoi(argv[3]);
 
-        const char* const outFileName = createOutFilename_orDie(inFileName);
-        compressFile_orDie(inFileName, outFileName, 5, frameSize, nbThreads);
+        compressFile_orDie(inFileName, 5, frameSize, nbThreads);
     }
 
     return 0;