]> git.ipfire.org Git - thirdparty/zstd.git/commitdiff
ensures that sampleSizes table is large enough
authorYann Collet <cyan@fb.com>
Fri, 15 Sep 2017 22:31:31 +0000 (15:31 -0700)
committerYann Collet <cyan@fb.com>
Fri, 15 Sep 2017 22:31:31 +0000 (15:31 -0700)
as recommended by @terrelln

programs/dibio.c

index 6e86e846316caf946f32aff3c7c1d33b31f23b4e..66f7d15264da2b4873eb40a4ed3a3c06bf853898 100644 (file)
@@ -97,11 +97,16 @@ const char* DiB_getErrorName(size_t errorCode) { return ERR_getErrorName(errorCo
 *  File related operations
 **********************************************************/
 /** DiB_loadFiles() :
- *  load files listed in fileNamesTable into buffer, even if buffer is too small.
- * @return : nb of files effectively loaded into `buffer`
- * *bufferSizePtr is modified, it provides the amount data loaded within buffer */
+ *  load samples from files listed in fileNamesTable into buffer.
+ *  works even if buffer is too small to load all samples.
+ *  Also provides the size of each sample into sampleSizes table
+ *  which must be sized correctly, using DiB_fileStats().
+ * @return : nb of samples effectively loaded into `buffer`
+ * *bufferSizePtr is modified, it provides the amount data loaded within buffer.
+ *  sampleSizes is filled with the size of each sample.
+ */
 static unsigned DiB_loadFiles(void* buffer, size_t* bufferSizePtr,
-                              size_t* chunkSizes,
+                              size_t* sampleSizes, unsigned sstSize,
                               const char** fileNamesTable, unsigned nbFiles, size_t targetChunkSize,
                               unsigned displayLevel)
 {
@@ -126,8 +131,12 @@ static unsigned DiB_loadFiles(void* buffer, size_t* bufferSizePtr,
             {   size_t const readSize = fread(buff+pos, 1, toLoad, f);
                 if (readSize != toLoad) EXM_THROW(11, "Pb reading %s", fileName);
                 pos += readSize;
-                chunkSizes[nbLoadedChunks++] = toLoad;
+                sampleSizes[nbLoadedChunks++] = toLoad;
                 remainingToLoad -= targetChunkSize;
+                if (nbLoadedChunks == sstSize) { /* no more space left in sampleSizes table */
+                    fileIndex = nbFiles;  /* stop there */
+                    break;
+                }
                 if (toLoad < targetChunkSize) {
                     fseek(f, (long)(targetChunkSize - toLoad), SEEK_CUR);
         }   }   }
@@ -221,9 +230,14 @@ static void DiB_saveDict(const char* dictFileName,
 typedef struct {
     U64 totalSizeToLoad;
     unsigned oneSampleTooLarge;
-    unsigned nbChunks;
+    unsigned nbSamples;
 } fileStats;
 
+/*! DiB_fileStats() :
+ *  Given a list of files, and a chunkSize (0 == no chunk, whole files)
+ *  provides the amount of data to be loaded and the resulting nb of samples.
+ *  This is useful primarily for allocation purpose => sample buffer, and sample sizes table.
+ */
 static fileStats DiB_fileStats(const char** fileNamesTable, unsigned nbFiles, size_t chunkSize, unsigned displayLevel)
 {
     fileStats fs;
@@ -231,12 +245,12 @@ static fileStats DiB_fileStats(const char** fileNamesTable, unsigned nbFiles, si
     memset(&fs, 0, sizeof(fs));
     for (n=0; n<nbFiles; n++) {
         U64 const fileSize = UTIL_getFileSize(fileNamesTable[n]);
-        U32 const nbChunks = (U32)(chunkSize ? (fileSize + (chunkSize-1)) / chunkSize : 1);
+        U32 const nbSamples = (U32)(chunkSize ? (fileSize + (chunkSize-1)) / chunkSize : 1);
         U64 const chunkToLoad = chunkSize ? MIN(chunkSize, fileSize) : fileSize;
         size_t const cappedChunkSize = (size_t)MIN(chunkToLoad, SAMPLESIZE_MAX);
-        fs.totalSizeToLoad += cappedChunkSize * nbChunks;
+        fs.totalSizeToLoad += cappedChunkSize * nbSamples;
         fs.oneSampleTooLarge |= (chunkSize > 2*SAMPLESIZE_MAX);
-        fs.nbChunks += nbChunks;
+        fs.nbSamples += nbSamples;
     }
     DISPLAYLEVEL(4, "Preparing to load : %u KB \n", (U32)(fs.totalSizeToLoad >> 10));
     return fs;
@@ -260,12 +274,12 @@ int DiB_trainFromFiles(const char* dictFileName, unsigned maxDictSize,
                        ZDICT_legacy_params_t *params, ZDICT_cover_params_t *coverParams,
                        int optimizeCover)
 {
-    unsigned displayLevel = params ? params->zParams.notificationLevel :
-                            coverParams ? coverParams->zParams.notificationLevel :
-                            0;   /* should never happen */
+    unsigned const displayLevel = params ? params->zParams.notificationLevel :
+                        coverParams ? coverParams->zParams.notificationLevel :
+                        0;   /* should never happen */
     void* const dictBuffer = malloc(maxDictSize);
     fileStats const fs = DiB_fileStats(fileNamesTable, nbFiles, chunkSize, displayLevel);
-    size_t* const chunkSizes = (size_t*)malloc(fs.nbChunks * sizeof(size_t));
+    size_t* const sampleSizes = (size_t*)malloc(fs.nbSamples * sizeof(size_t));
     size_t const memMult = params ? MEMMULT : COVER_MEMMULT;
     size_t const maxMem =  DiB_findMaxMem(fs.totalSizeToLoad * memMult) / memMult;
     size_t loadedSize = (size_t) MIN ((unsigned long long)maxMem, fs.totalSizeToLoad);
@@ -273,14 +287,14 @@ int DiB_trainFromFiles(const char* dictFileName, unsigned maxDictSize,
     int result = 0;
 
     /* Checks */
-    if ((!chunkSizes) || (!srcBuffer) || (!dictBuffer))
+    if ((!sampleSizes) || (!srcBuffer) || (!dictBuffer))
         EXM_THROW(12, "not enough memory for DiB_trainFiles");   /* should not happen */
     if (fs.oneSampleTooLarge) {
         DISPLAYLEVEL(2, "!  Warning : some sample(s) are very large \n");
         DISPLAYLEVEL(2, "!  Note that dictionary is only useful for small samples. \n");
         DISPLAYLEVEL(2, "!  As a consequence, only the first %u bytes of each sample are loaded \n", SAMPLESIZE_MAX);
     }
-    if (fs.nbChunks < 5) {
+    if (fs.nbSamples < 5) {
         DISPLAYLEVEL(2, "!  Warning : nb of samples too low for proper processing ! \n");
         DISPLAYLEVEL(2, "!  Please provide _one file per sample_. \n");
         EXM_THROW(14, "nb of samples too low");   /* we now clearly forbid this case */
@@ -297,24 +311,24 @@ int DiB_trainFromFiles(const char* dictFileName, unsigned maxDictSize,
     /* Load input buffer */
     DISPLAYLEVEL(3, "Shuffling input files\n");
     DiB_shuffle(fileNamesTable, nbFiles);
-    nbFiles = DiB_loadFiles(srcBuffer, &loadedSize, chunkSizes, fileNamesTable, nbFiles, chunkSize, displayLevel);
+    nbFiles = DiB_loadFiles(srcBuffer, &loadedSize, sampleSizes, fs.nbSamples, fileNamesTable, nbFiles, chunkSize, displayLevel);
 
     {   size_t dictSize;
         if (params) {
             DiB_fillNoise((char*)srcBuffer + loadedSize, NOISELENGTH);   /* guard band, for end of buffer condition */
             dictSize = ZDICT_trainFromBuffer_unsafe_legacy(dictBuffer, maxDictSize,
-                                                           srcBuffer, chunkSizes, fs.nbChunks,
+                                                           srcBuffer, sampleSizes, fs.nbSamples,
                                                            *params);
         } else if (optimizeCover) {
             dictSize = ZDICT_optimizeTrainFromBuffer_cover(dictBuffer, maxDictSize,
-                                                           srcBuffer, chunkSizes, fs.nbChunks,
+                                                           srcBuffer, sampleSizes, fs.nbSamples,
                                                            coverParams);
             if (!ZDICT_isError(dictSize)) {
                 DISPLAYLEVEL(2, "k=%u\nd=%u\nsteps=%u\n", coverParams->k, coverParams->d, coverParams->steps);
             }
         } else {
             dictSize = ZDICT_trainFromBuffer_cover(dictBuffer, maxDictSize, srcBuffer,
-                                                   chunkSizes, fs.nbChunks, *coverParams);
+                                                   sampleSizes, fs.nbSamples, *coverParams);
         }
         if (ZDICT_isError(dictSize)) {
             DISPLAYLEVEL(1, "dictionary training failed : %s \n", ZDICT_getErrorName(dictSize));   /* should not happen */
@@ -329,7 +343,7 @@ int DiB_trainFromFiles(const char* dictFileName, unsigned maxDictSize,
     /* clean up */
 _cleanup:
     free(srcBuffer);
-    free(chunkSizes);
+    free(sampleSizes);
     free(dictBuffer);
     return result;
 }