]> git.ipfire.org Git - thirdparty/zstd.git/commitdiff
added code to generate dictionary using finalizeDictionary
authorPaul Cruz <paulcruz74@fb.com>
Tue, 13 Jun 2017 18:54:43 +0000 (11:54 -0700)
committerPaul Cruz <paulcruz74@fb.com>
Tue, 13 Jun 2017 18:54:43 +0000 (11:54 -0700)
tests/decodecorpus.c

index 632acabb86006bd737310323c5c9c33d12249682..357d831e699b1b87ed45eb0aaa99fa46cbb809e6 100644 (file)
@@ -18,6 +18,7 @@
 #include "zstd.h"
 #include "zstd_internal.h"
 #include "mem.h"
+#include "zdict.h"
 
 // Direct access to internal compression functions is required
 #include "zstd_compress.c"
@@ -316,7 +317,8 @@ static void writeFrameHeader(U32* seed, frame_t* frame, int genDict, size_t dict
         op[pos++] = windowByte;
     }
     if(genDict) {
-        MEM_writeLE32(op + pos, (U32) dictSize);
+        MEM_writeLE32(op + pos, (U32) dictID);
+        pos += 4;
     }
     if (contentSizeFlag) {
         switch (fcsCode) {
@@ -608,7 +610,7 @@ static inline void initSeqStore(seqStore_t *seqStore) {
 
 /* Randomly generate sequence commands */
 static U32 generateSequences(U32* seed, frame_t* frame, seqStore_t* seqStore,
-                                size_t contentSize, size_t literalsSize, int genDict, size_t dictSize)
+                                size_t contentSize, size_t literalsSize, int genDict, size_t dictSize, BYTE* dictContent)
 {
     /* The total length of all the matches */
     size_t const remainingMatch = contentSize - literalsSize;
@@ -686,11 +688,17 @@ static U32 generateSequences(U32* seed, frame_t* frame, seqStore_t* seqStore,
                     repIndex = MIN(2, offsetCode + 1);
                 }
             }
-        } while (offset > (size_t)((BYTE*)srcPtr - (BYTE*)frame->srcStart) || offset == 0);
+        } while (((!genDict) && (offset > (size_t)((BYTE*)srcPtr - (BYTE*)frame->srcStart))) || offset == 0);
 
         {   size_t j;
             for (j = 0; j < matchLen; j++) {
-                *srcPtr = *(srcPtr-offset);
+                if(srcPtr-offset < frame->srcStart){
+                    /* copy from dictionary instead of literals */
+                    *srcPtr = *(dictContent + dictSize - (offset-(srcPtr-frame->srcStart)));
+                }
+                else{
+                    *srcPtr = *(srcPtr-offset);
+                }
                 srcPtr++;
             }
         }
@@ -940,7 +948,7 @@ static size_t writeSequences(U32* seed, frame_t* frame, seqStore_t* seqStorePtr,
 }
 
 static size_t writeSequencesBlock(U32* seed, frame_t* frame, size_t contentSize,
-                                  size_t literalsSize, int genDict, size_t dictSize)
+                                  size_t literalsSize, int genDict, size_t dictSize, BYTE* dictContent)
 {
     seqStore_t seqStore;
     size_t numSequences;
@@ -949,14 +957,14 @@ static size_t writeSequencesBlock(U32* seed, frame_t* frame, size_t contentSize,
     initSeqStore(&seqStore);
 
     /* randomly generate sequences */
-    numSequences = generateSequences(seed, frame, &seqStore, contentSize, literalsSize, genDict, dictSize);
+    numSequences = generateSequences(seed, frame, &seqStore, contentSize, literalsSize, genDict, dictSize, dictContent);
     /* write them out to the frame data */
     CHECKERR(writeSequences(seed, frame, &seqStore, numSequences));
 
     return numSequences;
 }
 
-static size_t writeCompressedBlock(U32* seed, frame_t* frame, size_t contentSize, int genDict, size_t dictSize)
+static size_t writeCompressedBlock(U32* seed, frame_t* frame, size_t contentSize, int genDict, size_t dictSize, BYTE* dictContent)
 {
     BYTE* const blockStart = (BYTE*)frame->data;
     size_t literalsSize;
@@ -968,7 +976,7 @@ static size_t writeCompressedBlock(U32* seed, frame_t* frame, size_t contentSize
 
     DISPLAYLEVEL(4, "   literals size: %u\n", (U32)literalsSize);
 
-    nbSeq = writeSequencesBlock(seed, frame, contentSize, literalsSize, genDict, dictSize);
+    nbSeq = writeSequencesBlock(seed, frame, contentSize, literalsSize, genDict, dictSize, dictContent);
 
     DISPLAYLEVEL(4, "   number of sequences: %u\n", (U32)nbSeq);
 
@@ -976,7 +984,7 @@ static size_t writeCompressedBlock(U32* seed, frame_t* frame, size_t contentSize
 }
 
 static void writeBlock(U32* seed, frame_t* frame, size_t contentSize,
-                       int lastBlock, int genDict, size_t dictSize)
+                       int lastBlock, int genDict, size_t dictSize, BYTE* dictContent)
 {
     int const blockTypeDesc = RAND(seed) % 8;
     size_t blockSize;
@@ -1016,7 +1024,7 @@ static void writeBlock(U32* seed, frame_t* frame, size_t contentSize,
         frame->oldStats = frame->stats;
 
         frame->data = op;
-        compressedSize = writeCompressedBlock(seed, frame, contentSize, genDict, dictSize);
+        compressedSize = writeCompressedBlock(seed, frame, contentSize, genDict, dictSize, dictContent);
         if (compressedSize > contentSize) {
             blockType = 0;
             memcpy(op, frame->src, contentSize);
@@ -1042,7 +1050,7 @@ static void writeBlock(U32* seed, frame_t* frame, size_t contentSize,
     frame->data = op;
 }
 
-static void writeBlocks(U32* seed, frame_t* frame, int genDict, size_t dictSize)
+static void writeBlocks(U32* seed, frame_t* frame, int genDict, size_t dictSize, BYTE* dictContent)
 {
     size_t contentLeft = frame->header.contentSize;
     size_t const maxBlockSize = MIN(MAX_BLOCK_SIZE, frame->header.windowSize);
@@ -1065,7 +1073,7 @@ static void writeBlocks(U32* seed, frame_t* frame, int genDict, size_t dictSize)
             }
         }
 
-        writeBlock(seed, frame, blockContentSize, lastBlock, genDict, dictSize);
+        writeBlock(seed, frame, blockContentSize, lastBlock, genDict, dictSize, dictContent);
 
         contentLeft -= blockContentSize;
         if (lastBlock) break;
@@ -1130,14 +1138,14 @@ static void initFrame(frame_t* fr)
 }
 
 /* Return the final seed */
-static U32 generateFrame(U32 seed, frame_t* fr, int genDict, size_t dictSize)
+static U32 generateFrame(U32 seed, frame_t* fr, int genDict, size_t dictSize, BYTE* dictContent)
 {
     /* generate a complete frame */
     DISPLAYLEVEL(1, "frame seed: %u\n", seed);
     initFrame(fr);
 
     writeFrameHeader(&seed, fr, genDict, dictSize);
-    writeBlocks(&seed, fr, genDict, dictSize);
+    writeBlocks(&seed, fr, genDict, dictSize, dictContent);
     writeChecksum(fr);
 
     return seed;
@@ -1224,7 +1232,7 @@ static int runTestMode(U32 seed, unsigned numFiles, unsigned const testDurationS
         else
             DISPLAYUPDATE("\r%u           ", fnum);
 
-        seed = generateFrame(seed, &fr, 0, 0);
+        seed = generateFrame(seed, &fr, 0, 0, NULL);
 
         {   size_t const r = testDecodeSimple(&fr);
             if (ZSTD_isError(r)) {
@@ -1259,7 +1267,7 @@ static int generateFile(U32 seed, const char* const path,
 
     DISPLAY("seed: %u\n", seed);
 
-    generateFrame(seed, &fr, 0, 0);
+    generateFrame(seed, &fr, 0, 0, NULL);
 
     outputBuffer(fr.dataStart, (BYTE*)fr.data - (BYTE*)fr.dataStart, path);
     if (origPath) {
@@ -1281,7 +1289,7 @@ static int generateCorpus(U32 seed, unsigned numFiles, const char* const path,
 
         DISPLAYUPDATE("\r%u/%u        ", fnum, numFiles);
 
-        seed = generateFrame(seed, &fr, 0, 0);
+        seed = generateFrame(seed, &fr, 0, 0, NULL);
 
         if (snprintf(outPath, MAX_PATH, "%s/z%06u.zst", path, fnum) + 1 > MAX_PATH) {
             DISPLAY("Error: path too long\n");
@@ -1308,9 +1316,11 @@ static int generateCorpusWithDict(U32 seed, unsigned numFiles, const char* const
 {
     const size_t minDictSize = 8;
     char outPath[MAX_PATH];
+    BYTE* dictContent;
+    BYTE* fullDict;
     U32 dictID;
-    BYTE* dictStart;
     unsigned fnum;
+    BYTE* decompressedPtr;
     ZSTD_DCtx* dctx = ZSTD_createDCtx();
     if(snprintf(outPath, MAX_PATH, "%s/dictionary", path) + 1 > MAX_PATH) {
         DISPLAY("Error: path too long\n");
@@ -1318,37 +1328,50 @@ static int generateCorpusWithDict(U32 seed, unsigned numFiles, const char* const
     }
 
     /* Generate the dictionary randomly first */
-    if(dictSize < minDictSize){
-        DISPLAY("Error: dictionary size (%zu) is too small\n", dictSize);
+    dictContent = malloc(dictSize-400);
+    dictID = RAND(&seed);
+    fullDict = malloc(dictSize);
+    RAND_buffer(&seed, dictContent, dictSize-40);
+    {
+        /* create random samples */
+        unsigned numSamples = RAND(&seed);
+        unsigned i = 0;
+        size_t* sampleSizes = malloc(numSamples*sizeof(size_t));
+        size_t* curr = sampleSizes;
+        size_t totalSize = 0;
+        while(i < numSamples){
+            *curr = RAND(&seed) % (4 << 20);
+            totalSize += *curr;
+            curr++;
+        }
+        ZDICT_params_t zdictParams;
+        BYTE* samples = malloc(totalSize);
+        RAND_buffer(&seed, samples, totalSize);
+
+        /* set dictionary params */
+        memset(&zdictParams, 0, sizeof(zdictParams));
+        zdictParams.notificationLevel = 1;
+        zdictParams.dictID = dictID;
+        zdictParams.compressionLevel = 5;
+
+        /* finalize dictionary with random samples */
+        ZDICT_finalizeDictionary(fullDict, dictSize,
+                                    dictContent, dictSize-400,
+                                    samples, sampleSizes, numSamples,
+                                    zdictParams);
     }
-    else{
-        /* variable declaration */
-        dictStart = malloc(dictSize);
-        size_t pos = 0;
-        dictID = RAND(&seed) + 1;
-
-        /* write dictionary magic number */
-        MEM_writeLE32(dictStart + pos, ZSTD_DICT_MAGIC);
-        pos += 4;
 
-        /* write random dictionary ID */
-        MEM_writeLE32(dictStart + pos, dictID);
-        pos += 4;
-
-        /* randomly generate the rest of the dictionary */
-        RAND_buffer(&seed, dictStart + pos, dictSize-8);
-        outputBuffer(dictStart, dictSize, outPath);
-    }
 
+    decompressedPtr = malloc(MAX_DECOMPRESSED_SIZE);
     /* generate random compressed/decompressed files */
     for (fnum = 0; fnum < numFiles; fnum++) {
         frame_t fr;
         size_t returnValue;
-        BYTE* decompressedPtr = malloc(MAX_DECOMPRESSED_SIZE);
+
 
         DISPLAYUPDATE("\r%u/%u        ", fnum, numFiles);
 
-        seed = generateFrame(seed, &fr, 1, dictSize);
+        seed = generateFrame(seed, &fr, 1, dictSize, dictContent);
 
         if (snprintf(outPath, MAX_PATH, "%s/z%06u.zst", path, fnum) + 1 > MAX_PATH) {
             DISPLAY("Error: path too long\n");
@@ -1368,13 +1391,10 @@ static int generateCorpusWithDict(U32 seed, unsigned numFiles, const char* const
 
         returnValue = ZSTD_decompress_usingDict(dctx, decompressedPtr, MAX_DECOMPRESSED_SIZE,
                                                fr.srcStart, (BYTE*)fr.src - (BYTE*)fr.srcStart,
-                                               dictStart,dictSize);
+                                               fullDict, dictSize);
 
     }
 
-
-    /* write uncompressed versions of files */
-    DISPLAY("This is origPath: %s\nAnd this is numFiles: %d\n", origPath, numFiles);
     return 0;
 }